diff --git a/.claude/skills/pr-polish/SKILL.md b/.claude/skills/pr-polish/SKILL.md new file mode 100644 index 0000000000..3b36adee14 --- /dev/null +++ b/.claude/skills/pr-polish/SKILL.md @@ -0,0 +1,245 @@ +--- +name: pr-polish +description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds. +user-invocable: true +argument-hint: "[PR number or URL] — if omitted, finds PR for current branch." +metadata: + author: autogpt-team + version: "1.0.0" +--- + +# PR Polish + +**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold: + +1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body). +2. Every inline review thread reachable via GraphQL reports `isResolved: true`. +3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned. +4. Every non-bot, non-author issue comment has been acknowledged (replied-to). +5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending. +6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient. + +**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 in practice. + +## TodoWrite + +Before starting, write two todos so the user can see the loop progression: + +- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration. +- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round. + +Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved). + +## Find the PR + +```bash +ARG_PR="${ARG:-}" +# Normalize URL → numeric ID if the skill arg is a pull-request URL. +if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then + ARG_PR="${BASH_REMATCH[1]}" +fi +PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}" +if [ -z "$PR" ] || [ "$PR" = "null" ]; then + echo "No PR found for current branch. Provide a PR number or URL as the skill arg." + exit 1 +fi +echo "Polishing PR #$PR" +``` + +## The outer loop + +```text +round = 0 +while round < _MAX_ROUNDS: + round += 1 + baseline = snapshot_state(PR) # see "Snapshotting state" below + invoke_skill("pr-review", PR) # posts findings as inline comments / top-level review + findings = diff_state(PR, baseline) + if findings.total == 0: + break # no new findings → go to polish polling + invoke_skill("pr-address", PR) # resolves every unresolved thread + CI failure +# Post-loop: polish polling (see below). 
+polish_polling(PR) +``` + +### Snapshotting state + +Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads): + +```bash +# Inline threads — total count + latest databaseId per thread +gh api graphql -f query=" +{ + repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") { + pullRequest(number: ${PR}) { + reviewThreads(first: 100) { + totalCount + nodes { + id + isResolved + comments(last: 1) { nodes { databaseId } } + } + } + } + } +}" > /tmp/baseline_threads.json + +# Top-level reviews — count + latest id per non-empty review +gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \ + --jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \ + > /tmp/baseline_reviews.json + +# Issue comments — count + latest id per non-bot, non-author comment. +# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot +# accounts like coderabbitai, github-actions, sentry-io). The author is +# filtered by comparing login to the PR author — export it so jq can see it. +AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login') +gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate \ + --jq --arg author "$AUTHOR" \ + '[.[] | select(.user.type != "Bot" and .user.login != $author) + | {id, user: .user.login, created_at}]' \ + > /tmp/baseline_issue_comments.json +``` + +### Diffing after a review + +After `/pr-review` runs, any of these counting as "new findings" means another address round is needed: + +- New inline thread `id` not in the baseline. +- An existing thread whose latest comment `databaseId` is higher than the baseline's (new reply on an old thread). +- A new top-level review `id` with a non-empty body. +- A new issue comment `id` from a non-bot, non-author user. + +If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop. + +## Polish polling + +Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals: + +```text +NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"} +clean_polls = 0 +required_clean = 2 +while clean_polls < required_clean: + # 1. CI gate — any terminal non-success conclusion (not just "failure") + # must trigger /pr-address. "success", "skipped", "neutral" are clean; + # anything else (including cancelled, timed_out, action_required) is a + # blocker that won't self-resolve. + ci = fetch_check_runs(PR) + if any ci.conclusion in NON_SUCCESS_TERMINAL: + invoke_skill("pr-address", PR) # address failures + any new comments + baseline = snapshot_state(PR) # reset — push during address invalidates old baseline + clean_polls = 0 + continue + if any ci.conclusion is None (still in_progress): + sleep 60; continue # wait without counting this as clean + + # 2. Comment / thread gate + threads = fetch_unresolved_threads(PR) + new_issue_comments = diff_against_baseline(issue_comments) + new_reviews = diff_against_baseline(reviews) + if threads or new_issue_comments or new_reviews: + invoke_skill("pr-address", PR) + baseline = snapshot_state(PR) # reset — the address loop just dealt with these, + # otherwise they stay "new" relative to the old baseline forever + clean_polls = 0 + continue + + # 3. 
Mergeability gate + mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable' + if mergeable == false (CONFLICTING): + resolve_conflicts(PR) # see pr-address skill + clean_polls = 0 + continue + if mergeable is null (UNKNOWN): + sleep 60; continue + + clean_polls += 1 + sleep 60 +``` + +Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`. + +### Why 2 clean polls, not 1 + +A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming. + +### Why checking every source each poll + +`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch: + +- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles). +- Issue comments from human reviewers (not caught by inline thread polling). +- Sentry bug predictions that land on new line numbers post-push. +- Merge conflicts introduced by a race between your push and a merge to `dev`. + +## Invocation pattern + +Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently. + +```python +Skill(skill="pr-review", args=pr_url) +Skill(skill="pr-address", args=pr_url) +``` + +After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section. + +### **Auto-continue: do NOT end your response between child skills** + +`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you: + +- Do NOT summarize and stop. +- Do NOT wait for user confirmation to continue. +- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep. + +The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error). + +If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions. + +## GitHub rate limits + +This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off: + +- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section. +- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile. +- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits. 
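+
+A minimal sketch of that budget check, run before each poll (the ~200 threshold is just the heuristic above; `USE_REST_READS` is a hypothetical flag for the orchestrator, not an existing setting):
+
+```bash
+# Check remaining GraphQL budget; below the threshold, prefer REST reads and defer GraphQL-only writes.
+REMAINING=$(gh api rate_limit --jq '.resources.graphql.remaining')
+if [ "${REMAINING:-0}" -lt 200 ]; then
+  echo "GraphQL budget low (${REMAINING} remaining); switching reads to REST and deferring thread resolutions"
+  USE_REST_READS=1   # hypothetical flag: read /pulls/{N}/comments, /pulls/{N}/reviews, /issues/{N}/comments instead of GraphQL
+fi
+```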
+ +## Safety valves + +- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment. +- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is CI `failure` that will never self-resolve. +- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved. + +## Reporting + +When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary: + +``` +PR #{N} polish complete ({rounds_completed} rounds): +- {X} inline threads opened and resolved +- {Y} CI failures fixed +- {Z} new commits pushed +Final state: CI green, {total} threads all resolved, mergeable. +``` + +If exiting via `_MAX_ROUNDS`, flag explicitly: + +``` +PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready: +- {N} threads still unresolved: {titles} +- CI status: {summary} +Needs human review. +``` + +## When to use this skill + +Use when the user says any of: +- "polish this PR" +- "keep reviewing and addressing until it's mergeable" +- "loop /pr-review + /pr-address until done" +- "make sure the PR is actually merge-ready" + +Do **not** use when: +- User wants just one review pass (→ `/pr-review`). +- User wants to address already-posted comments without further self-review (→ `/pr-address`). +- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging. diff --git a/.claude/skills/pr-test/SKILL.md b/.claude/skills/pr-test/SKILL.md index b368fb7f0d..0bea79ee03 100644 --- a/.claude/skills/pr-test/SKILL.md +++ b/.claude/skills/pr-test/SKILL.md @@ -186,7 +186,7 @@ Multiple worktrees share the same host — Docker infra (postgres, redis, clamav ### Lock file contract -Path (**always** the root worktree so all siblings see it): `/Users/majdyz/Code/AutoGPT/.ign.testing.lock` +Path (**always** the root worktree so all siblings see it): `$REPO_ROOT/.ign.testing.lock` Body (one `key=value` per line): ``` @@ -202,7 +202,7 @@ intent= ### Claim ```bash -LOCK=/Users/majdyz/Code/AutoGPT/.ign.testing.lock +LOCK=$REPO_ROOT/.ign.testing.lock NOW=$(date -u +%Y-%m-%dT%H:%MZ) STALE_AFTER_MIN=5 @@ -252,7 +252,7 @@ echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid kill "$HEARTBEAT_PID" 2>/dev/null rm -f "$LOCK" /tmp/pr-test-heartbeat.pid echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \ - >> /Users/majdyz/Code/AutoGPT/.ign.testing.log + >> $REPO_ROOT/.ign.testing.log ``` Use a `trap` so release runs even on `exit 1`: @@ -260,12 +260,38 @@ Use a `trap` so release runs even on `exit 1`: trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM ``` +### **Release the lock AS SOON AS the test run is done** + +The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if: + +- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually. +- You're leaving docker containers up. +- You're tailing logs for a minute or two. + +Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. 
**The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone. + +Concretely, the sequence at the end of every `/pr-test` run (success or failure) is: + +```bash +# 1. Write the final report + post PR comment — done above in Step 5/6. +# 2. Release the lock right now, even if the app is still up. +kill "$HEARTBEAT_PID" 2>/dev/null +rm -f "$LOCK" /tmp/pr-test-heartbeat.pid +echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \ + >> $REPO_ROOT/.ign.testing.log +# 3. Optionally leave the app running and note it so the user knows: +echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:" +echo " pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'" +``` + +If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test. + ### Shared status log -`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes: +`$REPO_ROOT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes: ```bash echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] " \ - >> /Users/majdyz/Code/AutoGPT/.ign.testing.log + >> $REPO_ROOT/.ign.testing.log ``` ## Step 3: Environment setup @@ -755,6 +781,19 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations **CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes. +**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path. + +**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`): + +```bash +# Reject any local-looking absolute path or home-dir shortcut in the body +if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then + echo "ABORT: local filesystem paths detected in PR comment body." + echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting." 
+ exit 1 +fi +``` + ```bash # Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely) REPO="Significant-Gravitas/AutoGPT" diff --git a/.gitignore b/.gitignore index 97d6b18a76..53df57dc70 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,4 @@ test.db # Implementation plans (generated by AI agents) plans/ .claude/worktrees/ +test-results/ diff --git a/autogpt_platform/.gitignore b/autogpt_platform/.gitignore index 3e31a9970e..bc70dc96bc 100644 --- a/autogpt_platform/.gitignore +++ b/autogpt_platform/.gitignore @@ -1,3 +1,6 @@ *.ignore.* *.ign.* .application.logs + +# Claude Code local settings only — the rest of .claude/ is shared (skills etc.) +.claude/settings.local.json diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/types.py b/autogpt_platform/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/types.py index 04c6fa2a77..eb69ab2fac 100644 --- a/autogpt_platform/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/types.py +++ b/autogpt_platform/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/types.py @@ -59,6 +59,8 @@ class OAuthState(BaseModel): code_verifier: Optional[str] = None scopes: list[str] """Unix timestamp (seconds) indicating when this OAuth state expires""" + credential_id: Optional[str] = None + """If set, this OAuth flow upgrades an existing credential's scopes.""" class UserMetadata(BaseModel): diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index e731f9f9bf..67444c2e36 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -179,6 +179,9 @@ MEM0_API_KEY= OPENWEATHERMAP_API_KEY= GOOGLE_MAPS_API_KEY= +# Platform Bot Linking +PLATFORM_LINK_BASE_URL=http://localhost:3000/link + # Communication Services DISCORD_BOT_TOKEN= MEDIUM_API_KEY= diff --git a/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py new file mode 100644 index 0000000000..4cb8ff0729 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py @@ -0,0 +1,932 @@ +import asyncio +import logging +from typing import List + +from autogpt_libs.auth import requires_admin_user +from autogpt_libs.auth.models import User as AuthUser +from fastapi import APIRouter, HTTPException, Security +from prisma.enums import AgentExecutionStatus +from pydantic import BaseModel + +from backend.api.features.admin.model import ( + AgentDiagnosticsResponse, + ExecutionDiagnosticsResponse, +) +from backend.data.diagnostics import ( + FailedExecutionDetail, + OrphanedScheduleDetail, + RunningExecutionDetail, + ScheduleDetail, + ScheduleHealthMetrics, + cleanup_all_stuck_queued_executions, + cleanup_orphaned_executions_bulk, + cleanup_orphaned_schedules_bulk, + get_agent_diagnostics, + get_all_orphaned_execution_ids, + get_all_schedules_details, + get_all_stuck_queued_execution_ids, + get_execution_diagnostics, + get_failed_executions_count, + get_failed_executions_details, + get_invalid_executions_details, + get_long_running_executions_details, + get_orphaned_executions_details, + get_orphaned_schedules_details, + get_running_executions_details, + get_schedule_health_metrics, + get_stuck_queued_executions_details, + stop_all_long_running_executions, +) +from backend.data.execution import get_graph_executions +from backend.executor.utils import add_graph_execution, 
stop_graph_execution + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/admin", + tags=["diagnostics", "admin"], + dependencies=[Security(requires_admin_user)], +) + + +class RunningExecutionsListResponse(BaseModel): + """Response model for list of running executions""" + + executions: List[RunningExecutionDetail] + total: int + + +class FailedExecutionsListResponse(BaseModel): + """Response model for list of failed executions""" + + executions: List[FailedExecutionDetail] + total: int + + +class StopExecutionRequest(BaseModel): + """Request model for stopping a single execution""" + + execution_id: str + + +class StopExecutionsRequest(BaseModel): + """Request model for stopping multiple executions""" + + execution_ids: List[str] + + +class StopExecutionResponse(BaseModel): + """Response model for stop execution operations""" + + success: bool + stopped_count: int = 0 + message: str + + +class RequeueExecutionResponse(BaseModel): + """Response model for requeue execution operations""" + + success: bool + requeued_count: int = 0 + message: str + + +@router.get( + "/diagnostics/executions", + response_model=ExecutionDiagnosticsResponse, + summary="Get Execution Diagnostics", +) +async def get_execution_diagnostics_endpoint(): + """ + Get comprehensive diagnostic information about execution status. + + Returns all execution metrics including: + - Current state (running, queued) + - Orphaned executions (>24h old, likely not in executor) + - Failure metrics (1h, 24h, rate) + - Long-running detection (stuck >1h, >24h) + - Stuck queued detection + - Throughput metrics (completions/hour) + - RabbitMQ queue depths + """ + logger.info("Getting execution diagnostics") + + diagnostics = await get_execution_diagnostics() + + response = ExecutionDiagnosticsResponse( + running_executions=diagnostics.running_count, + queued_executions_db=diagnostics.queued_db_count, + queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth, + cancel_queue_depth=diagnostics.cancel_queue_depth, + orphaned_running=diagnostics.orphaned_running, + orphaned_queued=diagnostics.orphaned_queued, + failed_count_1h=diagnostics.failed_count_1h, + failed_count_24h=diagnostics.failed_count_24h, + failure_rate_24h=diagnostics.failure_rate_24h, + stuck_running_24h=diagnostics.stuck_running_24h, + stuck_running_1h=diagnostics.stuck_running_1h, + oldest_running_hours=diagnostics.oldest_running_hours, + stuck_queued_1h=diagnostics.stuck_queued_1h, + queued_never_started=diagnostics.queued_never_started, + invalid_queued_with_start=diagnostics.invalid_queued_with_start, + invalid_running_without_start=diagnostics.invalid_running_without_start, + completed_1h=diagnostics.completed_1h, + completed_24h=diagnostics.completed_24h, + throughput_per_hour=diagnostics.throughput_per_hour, + timestamp=diagnostics.timestamp, + ) + + logger.info( + f"Execution diagnostics: running={diagnostics.running_count}, " + f"queued_db={diagnostics.queued_db_count}, " + f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, " + f"failed_24h={diagnostics.failed_count_24h}" + ) + + return response + + +@router.get( + "/diagnostics/agents", + response_model=AgentDiagnosticsResponse, + summary="Get Agent Diagnostics", +) +async def get_agent_diagnostics_endpoint(): + """ + Get diagnostic information about agents. 
+ + Returns: + - agents_with_active_executions: Number of unique agents with running/queued executions + - timestamp: Current timestamp + """ + logger.info("Getting agent diagnostics") + + diagnostics = await get_agent_diagnostics() + + response = AgentDiagnosticsResponse( + agents_with_active_executions=diagnostics.agents_with_active_executions, + timestamp=diagnostics.timestamp, + ) + + logger.info( + f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}" + ) + + return response + + +@router.get( + "/diagnostics/executions/running", + response_model=RunningExecutionsListResponse, + summary="List Running Executions", +) +async def list_running_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of running and queued executions (recent, likely active). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of running executions with details + """ + logger.info(f"Listing running executions (limit={limit}, offset={offset})") + + executions = await get_running_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.running_count + diagnostics.queued_db_count + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/orphaned", + response_model=RunningExecutionsListResponse, + summary="List Orphaned Executions", +) +async def list_orphaned_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of orphaned executions (>24h old, likely not in executor). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of orphaned executions with details + """ + logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})") + + executions = await get_orphaned_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.orphaned_running + diagnostics.orphaned_queued + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/failed", + response_model=FailedExecutionsListResponse, + summary="List Failed Executions", +) +async def list_failed_executions( + limit: int = 100, + offset: int = 0, + hours: int = 24, +): + """ + Get detailed list of failed executions. + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + hours: Number of hours to look back (default 24) + + Returns: + List of failed executions with error details + """ + logger.info( + f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})" + ) + + executions = await get_failed_executions_details( + limit=limit, offset=offset, hours=hours + ) + + # Get total count for pagination + # Always count actual total for given hours parameter + total = await get_failed_executions_count(hours=hours) + + return FailedExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/long-running", + response_model=RunningExecutionsListResponse, + summary="List Long-Running Executions", +) +async def list_long_running_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of long-running executions (RUNNING status >24h). 
+ + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of long-running executions with details + """ + logger.info(f"Listing long-running executions (limit={limit}, offset={offset})") + + executions = await get_long_running_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.stuck_running_24h + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/stuck-queued", + response_model=RunningExecutionsListResponse, + summary="List Stuck Queued Executions", +) +async def list_stuck_queued_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of stuck queued executions (QUEUED >1h, never started). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of stuck queued executions with details + """ + logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})") + + executions = await get_stuck_queued_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.stuck_queued_1h + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/invalid", + response_model=RunningExecutionsListResponse, + summary="List Invalid Executions", +) +async def list_invalid_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of executions in invalid states (READ-ONLY). + + Invalid states indicate data corruption and require manual investigation: + - QUEUED but has startedAt (impossible - can't start while queued) + - RUNNING but no startedAt (impossible - can't run without starting) + + ⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation. + + Each invalid execution likely has a different root cause (crashes, race conditions, + DB corruption). Investigate the execution history and logs to determine appropriate + action (manual cleanup, status fix, or leave as-is if system recovered). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of invalid state executions with details + """ + logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})") + + executions = await get_invalid_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = ( + diagnostics.invalid_queued_with_start + + diagnostics.invalid_running_without_start + ) + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.post( + "/diagnostics/executions/requeue", + response_model=RequeueExecutionResponse, + summary="Requeue Stuck Execution", +) +async def requeue_single_execution( + request: StopExecutionRequest, # Reuse same request model (has execution_id) + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue a stuck QUEUED execution (admin only). + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits. 
+ + Args: + request: Contains execution_id to requeue + + Returns: + Success status and message + """ + logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}") + + # Get the execution (validation - must be QUEUED) + executions = await get_graph_executions( + graph_exec_id=request.execution_id, + statuses=[AgentExecutionStatus.QUEUED], + ) + + if not executions: + raise HTTPException( + status_code=404, + detail="Execution not found or not in QUEUED status", + ) + + execution = executions[0] + + # Use add_graph_execution in requeue mode + await add_graph_execution( + graph_id=execution.graph_id, + user_id=execution.user_id, + graph_version=execution.graph_version, + graph_exec_id=request.execution_id, # Requeue existing execution + ) + + return RequeueExecutionResponse( + success=True, + requeued_count=1, + message="Execution requeued successfully", + ) + + +@router.post( + "/diagnostics/executions/requeue-bulk", + response_model=RequeueExecutionResponse, + summary="Requeue Multiple Stuck Executions", +) +async def requeue_multiple_executions( + request: StopExecutionsRequest, # Reuse same request model (has execution_ids) + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue multiple stuck QUEUED executions (admin only). + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits. + + Args: + request: Contains list of execution_ids to requeue + + Returns: + Number of executions requeued and success message + """ + logger.info( + f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions" + ) + + # Get executions by ID list (must be QUEUED) + executions = await get_graph_executions( + execution_ids=request.execution_ids, + statuses=[AgentExecutionStatus.QUEUED], + ) + + if not executions: + return RequeueExecutionResponse( + success=False, + requeued_count=0, + message="No QUEUED executions found to requeue", + ) + + # Requeue all executions in parallel using add_graph_execution + async def requeue_one(exec) -> bool: + try: + await add_graph_execution( + graph_id=exec.graph_id, + user_id=exec.user_id, + graph_version=exec.graph_version, + graph_exec_id=exec.id, # Requeue existing + ) + return True + except Exception as e: + logger.error(f"Failed to requeue {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[requeue_one(exec) for exec in executions], return_exceptions=False + ) + + requeued_count = sum(1 for success in results if success) + + return RequeueExecutionResponse( + success=requeued_count > 0, + requeued_count=requeued_count, + message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions", + ) + + +@router.post( + "/diagnostics/executions/stop", + response_model=StopExecutionResponse, + summary="Stop Single Execution", +) +async def stop_single_execution( + request: StopExecutionRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Stop a single execution (admin only). + + Uses robust stop_graph_execution which cascades to children and waits for termination. 
+ + Args: + request: Contains execution_id to stop + + Returns: + Success status and message + """ + logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}") + + # Get the execution to find its owner user_id (required by stop_graph_execution) + executions = await get_graph_executions( + graph_exec_id=request.execution_id, + ) + + if not executions: + raise HTTPException(status_code=404, detail="Execution not found") + + execution = executions[0] + + # Use robust stop_graph_execution (cascades to children, waits for termination) + await stop_graph_execution( + user_id=execution.user_id, + graph_exec_id=request.execution_id, + wait_timeout=15.0, + cascade=True, + ) + + return StopExecutionResponse( + success=True, + stopped_count=1, + message="Execution stopped successfully", + ) + + +@router.post( + "/diagnostics/executions/stop-bulk", + response_model=StopExecutionResponse, + summary="Stop Multiple Executions", +) +async def stop_multiple_executions( + request: StopExecutionsRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Stop multiple active executions (admin only). + + Uses robust stop_graph_execution which cascades to children and waits for termination. + + Args: + request: Contains list of execution_ids to stop + + Returns: + Number of executions stopped and success message + """ + + logger.info( + f"Admin {user.user_id} stopping {len(request.execution_ids)} executions" + ) + + # Get executions by ID list + executions = await get_graph_executions( + execution_ids=request.execution_ids, + ) + + if not executions: + return StopExecutionResponse( + success=False, + stopped_count=0, + message="No executions found", + ) + + # Stop all executions in parallel using robust stop_graph_execution + async def stop_one(exec) -> bool: + try: + await stop_graph_execution( + user_id=exec.user_id, + graph_exec_id=exec.id, + wait_timeout=15.0, + cascade=True, + ) + return True + except Exception as e: + logger.error(f"Failed to stop execution {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[stop_one(exec) for exec in executions], return_exceptions=False + ) + + stopped_count = sum(1 for success in results if success) + + return StopExecutionResponse( + success=stopped_count > 0, + stopped_count=stopped_count, + message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-orphaned", + response_model=StopExecutionResponse, + summary="Cleanup Orphaned Executions", +) +async def cleanup_orphaned_executions( + request: StopExecutionsRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup orphaned executions by directly updating DB status (admin only). + For executions in DB but not actually running in executor (old/stale records). 
+ + Args: + request: Contains list of execution_ids to cleanup + + Returns: + Number of executions cleaned up and success message + """ + logger.info( + f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions" + ) + + cleaned_count = await cleanup_orphaned_executions_bulk( + request.execution_ids, user.user_id + ) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions", + ) + + +# ============================================================================ +# SCHEDULE DIAGNOSTICS ENDPOINTS +# ============================================================================ + + +class SchedulesListResponse(BaseModel): + """Response model for list of schedules""" + + schedules: List[ScheduleDetail] + total: int + + +class OrphanedSchedulesListResponse(BaseModel): + """Response model for list of orphaned schedules""" + + schedules: List[OrphanedScheduleDetail] + total: int + + +class ScheduleCleanupRequest(BaseModel): + """Request model for cleaning up schedules""" + + schedule_ids: List[str] + + +class ScheduleCleanupResponse(BaseModel): + """Response model for schedule cleanup operations""" + + success: bool + deleted_count: int = 0 + message: str + + +@router.get( + "/diagnostics/schedules", + response_model=ScheduleHealthMetrics, + summary="Get Schedule Diagnostics", +) +async def get_schedule_diagnostics_endpoint(): + """ + Get comprehensive diagnostic information about schedule health. + + Returns schedule metrics including: + - Total schedules (user vs system) + - Orphaned schedules by category + - Upcoming executions + """ + logger.info("Getting schedule diagnostics") + + diagnostics = await get_schedule_health_metrics() + + logger.info( + f"Schedule diagnostics: total={diagnostics.total_schedules}, " + f"user={diagnostics.user_schedules}, " + f"orphaned={diagnostics.total_orphaned}" + ) + + return diagnostics + + +@router.get( + "/diagnostics/schedules/all", + response_model=SchedulesListResponse, + summary="List All User Schedules", +) +async def list_all_schedules( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of all user schedules (excludes system monitoring jobs). + + Args: + limit: Maximum number of schedules to return (default 100) + offset: Number of schedules to skip (default 0) + + Returns: + List of schedules with details + """ + logger.info(f"Listing all schedules (limit={limit}, offset={offset})") + + schedules = await get_all_schedules_details(limit=limit, offset=offset) + + # Get total count + diagnostics = await get_schedule_health_metrics() + total = diagnostics.user_schedules + + return SchedulesListResponse(schedules=schedules, total=total) + + +@router.get( + "/diagnostics/schedules/orphaned", + response_model=OrphanedSchedulesListResponse, + summary="List Orphaned Schedules", +) +async def list_orphaned_schedules(): + """ + Get detailed list of orphaned schedules with orphan reasons. 
+ + Returns: + List of orphaned schedules categorized by orphan type + """ + logger.info("Listing orphaned schedules") + + schedules = await get_orphaned_schedules_details() + + return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules)) + + +@router.post( + "/diagnostics/schedules/cleanup-orphaned", + response_model=ScheduleCleanupResponse, + summary="Cleanup Orphaned Schedules", +) +async def cleanup_orphaned_schedules( + request: ScheduleCleanupRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup orphaned schedules by deleting from scheduler (admin only). + + Args: + request: Contains list of schedule_ids to delete + + Returns: + Number of schedules deleted and success message + """ + logger.info( + f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules" + ) + + deleted_count = await cleanup_orphaned_schedules_bulk( + request.schedule_ids, user.user_id + ) + + return ScheduleCleanupResponse( + success=deleted_count > 0, + deleted_count=deleted_count, + message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules", + ) + + +@router.post( + "/diagnostics/executions/stop-all-long-running", + response_model=StopExecutionResponse, + summary="Stop ALL Long-Running Executions", +) +async def stop_all_long_running_executions_endpoint( + user: AuthUser = Security(requires_admin_user), +): + """ + Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only). + Operates on entire dataset, not limited to pagination. + + Returns: + Number of executions stopped and success message + """ + logger.info(f"Admin {user.user_id} stopping ALL long-running executions") + + stopped_count = await stop_all_long_running_executions(user.user_id) + + return StopExecutionResponse( + success=stopped_count > 0, + stopped_count=stopped_count, + message=f"Stopped {stopped_count} long-running executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-all-orphaned", + response_model=StopExecutionResponse, + summary="Cleanup ALL Orphaned Executions", +) +async def cleanup_all_orphaned_executions( + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup ALL orphaned executions (>24h old) by directly updating DB status. + Operates on all executions, not just paginated results. + + Returns: + Number of executions cleaned up and success message + """ + logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions") + + # Fetch all orphaned execution IDs + execution_ids = await get_all_orphaned_execution_ids() + + if not execution_ids: + return StopExecutionResponse( + success=True, + stopped_count=0, + message="No orphaned executions to cleanup", + ) + + cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} orphaned executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-all-stuck-queued", + response_model=StopExecutionResponse, + summary="Cleanup ALL Stuck Queued Executions", +) +async def cleanup_all_stuck_queued_executions_endpoint( + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only). + Operates on entire dataset, not limited to pagination. 
+ + Returns: + Number of executions cleaned up and success message + """ + logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions") + + cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} stuck queued executions", + ) + + +@router.post( + "/diagnostics/executions/requeue-all-stuck", + response_model=RequeueExecutionResponse, + summary="Requeue ALL Stuck Queued Executions", +) +async def requeue_all_stuck_executions( + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ. + Operates on all executions, not just paginated results. + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits. + + Returns: + Number of executions requeued and success message + """ + logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions") + + # Fetch all stuck queued execution IDs + execution_ids = await get_all_stuck_queued_execution_ids() + + if not execution_ids: + return RequeueExecutionResponse( + success=True, + requeued_count=0, + message="No stuck queued executions to requeue", + ) + + # Get stuck executions by ID list (must be QUEUED) + executions = await get_graph_executions( + execution_ids=execution_ids, + statuses=[AgentExecutionStatus.QUEUED], + ) + + # Requeue all in parallel using add_graph_execution + async def requeue_one(exec) -> bool: + try: + await add_graph_execution( + graph_id=exec.graph_id, + user_id=exec.user_id, + graph_version=exec.graph_version, + graph_exec_id=exec.id, # Requeue existing + ) + return True + except Exception as e: + logger.error(f"Failed to requeue {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[requeue_one(exec) for exec in executions], return_exceptions=False + ) + + requeued_count = sum(1 for success in results if success) + + return RequeueExecutionResponse( + success=requeued_count > 0, + requeued_count=requeued_count, + message=f"Requeued {requeued_count} stuck executions", + ) diff --git a/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py new file mode 100644 index 0000000000..a3783312b0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py @@ -0,0 +1,889 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +import fastapi +import fastapi.testclient +import pytest +import pytest_mock +from autogpt_libs.auth.jwt_utils import get_jwt_payload +from prisma.enums import AgentExecutionStatus + +import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes +from backend.data.diagnostics import ( + AgentDiagnosticsSummary, + ExecutionDiagnosticsSummary, + FailedExecutionDetail, + OrphanedScheduleDetail, + RunningExecutionDetail, + ScheduleDetail, + ScheduleHealthMetrics, +) +from backend.data.execution import GraphExecutionMeta + +app = fastapi.FastAPI() +app.include_router(diagnostics_admin_routes.router) + +client = fastapi.testclient.TestClient(app) + + +@pytest.fixture(autouse=True) +def setup_app_admin_auth(mock_jwt_admin): + """Setup admin auth overrides for all tests in this module""" + app.dependency_overrides[get_jwt_payload] = 
mock_jwt_admin["get_jwt_payload"] + yield + app.dependency_overrides.clear() + + +def test_get_execution_diagnostics_success( + mocker: pytest_mock.MockFixture, +): + """Test fetching execution diagnostics with invalid state detection""" + mock_diagnostics = ExecutionDiagnosticsSummary( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=2, + orphaned_queued=1, + failed_count_1h=5, + failed_count_24h=20, + failure_rate_24h=0.83, + stuck_running_24h=1, + stuck_running_1h=3, + oldest_running_hours=26.5, + stuck_queued_1h=2, + queued_never_started=1, + invalid_queued_with_start=1, # New invalid state + invalid_running_without_start=1, # New invalid state + completed_1h=50, + completed_24h=1200, + throughput_per_hour=50.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=mock_diagnostics, + ) + + response = client.get("/admin/diagnostics/executions") + + assert response.status_code == 200 + data = response.json() + + # Verify new invalid state fields are included + assert data["invalid_queued_with_start"] == 1 + assert data["invalid_running_without_start"] == 1 + # Verify all expected fields present + assert "running_executions" in data + assert "orphaned_running" in data + assert "failed_count_24h" in data + + +def test_list_invalid_executions( + mocker: pytest_mock.MockFixture, +): + """Test listing executions in invalid states (read-only endpoint)""" + mock_invalid_executions = [ + RunningExecutionDetail( + execution_id="exec-invalid-1", + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status="QUEUED", + created_at=datetime.now(timezone.utc), + started_at=datetime.now( + timezone.utc + ), # QUEUED but has startedAt - INVALID! + queue_status=None, + ), + RunningExecutionDetail( + execution_id="exec-invalid-2", + graph_id="graph-456", + graph_name="Another Graph", + graph_version=2, + user_id="user-456", + user_email="user@example.com", + status="RUNNING", + created_at=datetime.now(timezone.utc), + started_at=None, # RUNNING but no startedAt - INVALID! 
+ queue_status=None, + ), + ] + + mock_diagnostics = ExecutionDiagnosticsSummary( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=0, + orphaned_queued=0, + failed_count_1h=0, + failed_count_24h=0, + failure_rate_24h=0.0, + stuck_running_24h=0, + stuck_running_1h=0, + oldest_running_hours=None, + stuck_queued_1h=0, + queued_never_started=0, + invalid_queued_with_start=1, + invalid_running_without_start=1, + completed_1h=0, + completed_24h=0, + throughput_per_hour=0.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details", + return_value=mock_invalid_executions, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=mock_diagnostics, + ) + + response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # Sum of both invalid state types + assert len(data["executions"]) == 2 + # Verify both types of invalid states are returned + assert data["executions"][0]["execution_id"] in [ + "exec-invalid-1", + "exec-invalid-2", + ] + assert data["executions"][1]["execution_id"] in [ + "exec-invalid-1", + "exec-invalid-2", + ] + + +def test_requeue_single_execution_with_add_graph_execution( + mocker: pytest_mock.MockFixture, + admin_user_id: str, +): + """Test requeueing uses add_graph_execution in requeue mode""" + mock_exec_meta = GraphExecutionMeta( + id="exec-stuck-123", + user_id="user-123", + graph_id="graph-456", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[mock_exec_meta], + ) + + mock_add_graph_execution = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/requeue", + json={"execution_id": "exec-stuck-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 1 + + # Verify it used add_graph_execution in requeue mode + mock_add_graph_execution.assert_called_once() + call_kwargs = mock_add_graph_execution.call_args.kwargs + assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode! 
+ assert call_kwargs["graph_id"] == "graph-456" + assert call_kwargs["user_id"] == "user-123" + + +def test_stop_single_execution_with_stop_graph_execution( + mocker: pytest_mock.MockFixture, + admin_user_id: str, +): + """Test stopping uses robust stop_graph_execution""" + mock_exec_meta = GraphExecutionMeta( + id="exec-running-123", + user_id="user-789", + graph_id="graph-999", + graph_version=2, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.RUNNING, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[mock_exec_meta], + ) + + mock_stop_graph_execution = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/stop", + json={"execution_id": "exec-running-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 1 + + # Verify it used stop_graph_execution with cascade + mock_stop_graph_execution.assert_called_once() + call_kwargs = mock_stop_graph_execution.call_args.kwargs + assert call_kwargs["graph_exec_id"] == "exec-running-123" + assert call_kwargs["user_id"] == "user-789" + assert call_kwargs["cascade"] is True # Stops children too! + assert call_kwargs["wait_timeout"] == 15.0 + + +def test_requeue_not_queued_execution_fails( + mocker: pytest_mock.MockFixture, +): + """Test that requeue fails if execution is not in QUEUED status""" + # Mock an execution that's RUNNING (not QUEUED) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], # No QUEUED executions found + ) + + response = client.post( + "/admin/diagnostics/executions/requeue", + json={"execution_id": "exec-running-123"}, + ) + + assert response.status_code == 404 + assert "not found or not in QUEUED status" in response.json()["detail"] + + +def test_list_invalid_executions_no_bulk_actions( + mocker: pytest_mock.MockFixture, +): + """Verify invalid executions endpoint is read-only (no bulk actions)""" + # This is a documentation test - the endpoint exists but should not + # have corresponding cleanup/stop/requeue endpoints + + # These endpoints should NOT exist for invalid states: + invalid_bulk_endpoints = [ + "/admin/diagnostics/executions/cleanup-invalid", + "/admin/diagnostics/executions/stop-invalid", + "/admin/diagnostics/executions/requeue-invalid", + ] + + for endpoint in invalid_bulk_endpoints: + response = client.post(endpoint, json={"execution_ids": ["test"]}) + assert response.status_code == 404, f"{endpoint} should not exist (read-only)" + + +def test_execution_ids_filter_efficiency( + mocker: pytest_mock.MockFixture, +): + """Test that bulk operations use efficient execution_ids filter""" + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + for i in range(3) + ] + + mock_get_graph_executions = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + + 
mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/requeue-bulk", + json={"execution_ids": ["exec-0", "exec-1", "exec-2"]}, + ) + + assert response.status_code == 200 + + # Verify it used execution_ids filter (not fetching all queued) + mock_get_graph_executions.assert_called_once() + call_kwargs = mock_get_graph_executions.call_args.kwargs + assert "execution_ids" in call_kwargs + assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"] + assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED] + + +# --------------------------------------------------------------------------- +# Helper: reusable mock diagnostics summary +# --------------------------------------------------------------------------- + + +def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary: + defaults = dict( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=2, + orphaned_queued=1, + failed_count_1h=5, + failed_count_24h=20, + failure_rate_24h=0.83, + stuck_running_24h=3, + stuck_running_1h=5, + oldest_running_hours=26.5, + stuck_queued_1h=2, + queued_never_started=1, + invalid_queued_with_start=1, + invalid_running_without_start=1, + completed_1h=50, + completed_24h=1200, + throughput_per_hour=50.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + defaults.update(overrides) + return ExecutionDiagnosticsSummary(**defaults) + + +_SENTINEL = object() + + +def _make_mock_execution( + exec_id: str = "exec-1", + status: str = "RUNNING", + started_at: datetime | None | object = _SENTINEL, +) -> RunningExecutionDetail: + return RunningExecutionDetail( + execution_id=exec_id, + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status=status, + created_at=datetime.now(timezone.utc), + started_at=( + datetime.now(timezone.utc) if started_at is _SENTINEL else started_at + ), + queue_status=None, + ) + + +def _make_mock_failed_execution( + exec_id: str = "exec-fail-1", +) -> FailedExecutionDetail: + return FailedExecutionDetail( + execution_id=exec_id, + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status="FAILED", + created_at=datetime.now(timezone.utc), + started_at=datetime.now(timezone.utc), + failed_at=datetime.now(timezone.utc), + error_message="Something went wrong", + ) + + +def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics: + defaults = dict( + total_schedules=15, + user_schedules=10, + system_schedules=5, + orphaned_deleted_graph=2, + orphaned_no_library_access=1, + orphaned_invalid_credentials=0, + orphaned_validation_failed=0, + total_orphaned=3, + schedules_next_hour=4, + schedules_next_24h=8, + total_runs_next_hour=12, + total_runs_next_24h=48, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + defaults.update(overrides) + return ScheduleHealthMetrics(**defaults) + + +# --------------------------------------------------------------------------- +# GET endpoints: execution list variants +# --------------------------------------------------------------------------- + + +def test_list_running_executions(mocker: pytest_mock.MockFixture): + mock_execs = [ + _make_mock_execution("exec-run-1"), + _make_mock_execution("exec-run-2"), + ] + mocker.patch( + 
"backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 15 # running_count(10) + queued_db_count(5) + assert len(data["executions"]) == 2 + assert data["executions"][0]["execution_id"] == "exec-run-1" + + +def test_list_orphaned_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 # orphaned_running(2) + orphaned_queued(1) + assert len(data["executions"]) == 1 + + +def test_list_failed_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_failed_execution("exec-fail-1")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count", + return_value=42, + ) + + response = client.get( + "/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 42 + assert len(data["executions"]) == 1 + assert data["executions"][0]["error_message"] == "Something went wrong" + + +def test_list_long_running_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_execution("exec-long-1")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get( + "/admin/diagnostics/executions/long-running?limit=50&offset=0" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 # stuck_running_24h + assert len(data["executions"]) == 1 + + +def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture): + mock_execs = [ + _make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get( + "/admin/diagnostics/executions/stuck-queued?limit=50&offset=0" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # stuck_queued_1h + assert len(data["executions"]) == 1 + + +# --------------------------------------------------------------------------- +# GET endpoints: agent + schedule diagnostics +# --------------------------------------------------------------------------- + + +def 
test_get_agent_diagnostics(mocker: pytest_mock.MockFixture): + mock_diag = AgentDiagnosticsSummary( + agents_with_active_executions=7, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics", + return_value=mock_diag, + ) + + response = client.get("/admin/diagnostics/agents") + + assert response.status_code == 200 + data = response.json() + assert data["agents_with_active_executions"] == 7 + + +def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture): + mock_metrics = _make_mock_schedule_health() + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics", + return_value=mock_metrics, + ) + + response = client.get("/admin/diagnostics/schedules") + + assert response.status_code == 200 + data = response.json() + assert data["user_schedules"] == 10 + assert data["total_orphaned"] == 3 + assert data["total_runs_next_hour"] == 12 + + +def test_list_all_schedules(mocker: pytest_mock.MockFixture): + mock_schedules = [ + ScheduleDetail( + schedule_id="sched-1", + schedule_name="Daily Run", + graph_id="graph-1", + graph_name="My Agent", + graph_version=1, + user_id="user-1", + user_email="alice@example.com", + cron="0 9 * * *", + timezone="UTC", + next_run_time=datetime.now(timezone.utc).isoformat(), + ), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details", + return_value=mock_schedules, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics", + return_value=_make_mock_schedule_health(), + ) + + response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 10 + assert len(data["schedules"]) == 1 + assert data["schedules"][0]["schedule_name"] == "Daily Run" + + +def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture): + mock_orphans = [ + OrphanedScheduleDetail( + schedule_id="sched-orphan-1", + schedule_name="Ghost Schedule", + graph_id="graph-deleted", + graph_version=1, + user_id="user-1", + orphan_reason="deleted_graph", + error_detail=None, + next_run_time=datetime.now(timezone.utc).isoformat(), + ), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details", + return_value=mock_orphans, + ) + + response = client.get("/admin/diagnostics/schedules/orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert data["schedules"][0]["orphan_reason"] == "deleted_graph" + + +# --------------------------------------------------------------------------- +# POST endpoints: bulk stop, cleanup, requeue +# --------------------------------------------------------------------------- + + +def test_stop_multiple_executions(mocker: pytest_mock.MockFixture): + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.RUNNING, + started_at=datetime.now(timezone.utc), + ended_at=None, + stats=None, + ) + for i in range(2) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution", + return_value=AsyncMock(), + ) + + response 
= client.post( + "/admin/diagnostics/executions/stop-bulk", + json={"execution_ids": ["exec-0", "exec-1"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 2 + + +def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/stop-bulk", + json={"execution_ids": ["nonexistent"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["stopped_count"] == 0 + + +def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk", + return_value=3, + ) + + response = client.post( + "/admin/diagnostics/executions/cleanup-orphaned", + json={"execution_ids": ["exec-1", "exec-2", "exec-3"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 3 + + +def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk", + return_value=2, + ) + + response = client.post( + "/admin/diagnostics/schedules/cleanup-orphaned", + json={"schedule_ids": ["sched-1", "sched-2"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["deleted_count"] == 2 + + +def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions", + return_value=5, + ) + + response = client.post("/admin/diagnostics/executions/stop-all-long-running") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 5 + + +def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids", + return_value=["exec-1", "exec-2"], + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk", + return_value=2, + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 2 + + +def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids", + return_value=[], + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 0 + assert "No orphaned" in data["message"] + + +def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions", + return_value=4, + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 4 + + +def 
test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture): + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-stuck-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=None, + ended_at=None, + stats=None, + ) + for i in range(3) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids", + return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"], + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post("/admin/diagnostics/executions/requeue-all-stuck") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 3 + + +def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids", + return_value=[], + ) + + response = client.post("/admin/diagnostics/executions/requeue-all-stuck") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 0 + assert "No stuck" in data["message"] + + +def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/requeue-bulk", + json={"execution_ids": ["nonexistent"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["requeued_count"] == 0 + + +def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/stop", + json={"execution_id": "nonexistent"}, + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] diff --git a/autogpt_platform/backend/backend/api/features/admin/model.py b/autogpt_platform/backend/backend/api/features/admin/model.py index 82f51e8e7a..c96c6d6433 100644 --- a/autogpt_platform/backend/backend/api/features/admin/model.py +++ b/autogpt_platform/backend/backend/api/features/admin/model.py @@ -14,3 +14,70 @@ class UserHistoryResponse(BaseModel): class AddUserCreditsResponse(BaseModel): new_balance: int transaction_key: str + + +class ExecutionDiagnosticsResponse(BaseModel): + """Response model for execution diagnostics""" + + # Current execution state + running_executions: int + queued_executions_db: int + queued_executions_rabbitmq: int + cancel_queue_depth: int + + # Orphaned execution detection + orphaned_running: int + orphaned_queued: int + + # Failure metrics + failed_count_1h: int + failed_count_24h: int + failure_rate_24h: float + + # Long-running detection + stuck_running_24h: int + stuck_running_1h: int + oldest_running_hours: float | None + + # Stuck queued detection + stuck_queued_1h: int + queued_never_started: int + + # Invalid state detection (data corruption - no auto-actions) + invalid_queued_with_start: int + invalid_running_without_start: int + 
+ # Throughput metrics + completed_1h: int + completed_24h: int + throughput_per_hour: float + + timestamp: str + + +class AgentDiagnosticsResponse(BaseModel): + """Response model for agent diagnostics""" + + agents_with_active_executions: int + timestamp: str + + +class ScheduleHealthMetrics(BaseModel): + """Response model for schedule diagnostics""" + + total_schedules: int + user_schedules: int + system_schedules: int + + # Orphan detection + orphaned_deleted_graph: int + orphaned_no_library_access: int + orphaned_invalid_credentials: int + orphaned_validation_failed: int + total_orphaned: int + + # Upcoming + schedules_next_hour: int + schedules_next_24h: int + + timestamp: str diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py index 379b9e9257..3b9c762f21 100644 --- a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py @@ -32,10 +32,10 @@ router = APIRouter( class UserRateLimitResponse(BaseModel): user_id: str user_email: Optional[str] = None - daily_token_limit: int - weekly_token_limit: int - daily_tokens_used: int - weekly_tokens_used: int + daily_cost_limit_microdollars: int + weekly_cost_limit_microdollars: int + daily_cost_used_microdollars: int + weekly_cost_used_microdollars: int tier: SubscriptionTier @@ -101,17 +101,19 @@ async def get_user_rate_limit( logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id) daily_limit, weekly_limit, tier = await get_global_rate_limits( - resolved_id, config.daily_token_limit, config.weekly_token_limit + resolved_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier) return UserRateLimitResponse( user_id=resolved_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, tier=tier, ) @@ -141,7 +143,9 @@ async def reset_user_rate_limit( raise HTTPException(status_code=500, detail="Failed to reset usage") from e daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier) @@ -154,10 +158,10 @@ async def reset_user_rate_limit( return UserRateLimitResponse( user_id=user_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, tier=tier, ) diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py index 77e4a656fb..95c3e589cb 100644 --- 
a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py @@ -57,7 +57,7 @@ def _patch_rate_limit_deps( mocker.patch( f"{_MOCK_MODULE}.get_global_rate_limits", new_callable=AsyncMock, - return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE), + return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC), ) mocker.patch( f"{_MOCK_MODULE}.get_usage_status", @@ -85,11 +85,11 @@ def test_get_rate_limit( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 - assert data["weekly_token_limit"] == 12_500_000 - assert data["daily_tokens_used"] == 500_000 - assert data["weekly_tokens_used"] == 3_000_000 - assert data["tier"] == "FREE" + assert data["daily_cost_limit_microdollars"] == 2_500_000 + assert data["weekly_cost_limit_microdollars"] == 12_500_000 + assert data["daily_cost_used_microdollars"] == 500_000 + assert data["weekly_cost_used_microdollars"] == 3_000_000 + assert data["tier"] == "BASIC" configured_snapshot.assert_match( json.dumps(data, indent=2, sort_keys=True) + "\n", @@ -117,7 +117,7 @@ def test_get_rate_limit_by_email( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 + assert data["daily_cost_limit_microdollars"] == 2_500_000 def test_get_rate_limit_by_email_not_found( @@ -160,10 +160,10 @@ def test_reset_user_usage_daily_only( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 + assert data["daily_cost_used_microdollars"] == 0 # Weekly is untouched - assert data["weekly_tokens_used"] == 3_000_000 - assert data["tier"] == "FREE" + assert data["weekly_cost_used_microdollars"] == 3_000_000 + assert data["tier"] == "BASIC" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False) @@ -192,9 +192,9 @@ def test_reset_user_usage_daily_and_weekly( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 - assert data["weekly_tokens_used"] == 0 - assert data["tier"] == "FREE" + assert data["daily_cost_used_microdollars"] == 0 + assert data["weekly_cost_used_microdollars"] == 0 + assert data["tier"] == "BASIC" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True) @@ -231,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure( mocker.patch( f"{_MOCK_MODULE}.get_global_rate_limits", new_callable=AsyncMock, - return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE), + return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC), ) mocker.patch( f"{_MOCK_MODULE}.get_usage_status", @@ -324,7 +324,7 @@ def test_set_user_tier( mocker.patch( f"{_MOCK_MODULE}.get_user_tier", new_callable=AsyncMock, - return_value=SubscriptionTier.FREE, + return_value=SubscriptionTier.BASIC, ) mock_set = mocker.patch( f"{_MOCK_MODULE}.set_user_tier", @@ -347,7 +347,7 @@ def test_set_user_tier_downgrade( mocker: pytest_mock.MockerFixture, target_user_id: str, ) -> None: - """Test downgrading a user's tier from PRO to FREE.""" + """Test downgrading a user's tier from PRO to BASIC.""" mocker.patch( f"{_MOCK_MODULE}.get_user_email_by_id", new_callable=AsyncMock, @@ -365,14 +365,14 @@ def test_set_user_tier_downgrade( response = client.post( "/admin/rate_limit/tier", - json={"user_id": target_user_id, "tier": "FREE"}, + json={"user_id": target_user_id, "tier": "BASIC"}, ) 
assert response.status_code == 200 data = response.json() assert data["user_id"] == target_user_id - assert data["tier"] == "FREE" - mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE) + assert data["tier"] == "BASIC" + mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.BASIC) def test_set_user_tier_invalid_tier( @@ -456,7 +456,7 @@ def test_set_user_tier_db_failure( mocker.patch( f"{_MOCK_MODULE}.get_user_tier", new_callable=AsyncMock, - return_value=SubscriptionTier.FREE, + return_value=SubscriptionTier.BASIC, ) mocker.patch( f"{_MOCK_MODULE}.set_user_tier", diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index eceedb828c..d317d677a5 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from backend.copilot import service as chat_service from backend.copilot import stream_registry +from backend.copilot.builder_context import resolve_session_permissions from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode from backend.copilot.db import get_chat_messages_paginated from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn @@ -24,6 +25,7 @@ from backend.copilot.model import ( create_chat_session, delete_chat_session, get_chat_session, + get_or_create_builder_session, get_user_sessions, update_session_title, ) @@ -34,7 +36,7 @@ from backend.copilot.pending_message_helpers import ( ) from backend.copilot.pending_messages import peek_pending_messages from backend.copilot.rate_limit import ( - CoPilotUsageStatus, + CoPilotUsagePublic, RateLimitExceeded, acquire_reset_lock, check_rate_limit, @@ -73,6 +75,7 @@ from backend.copilot.tools.models import ( NoResultsResponse, SetupRequirementsResponse, SuggestedGoalResponse, + TodoWriteResponse, UnderstandingUpdatedResponse, ) from backend.copilot.tracking import track_user_message @@ -133,7 +136,7 @@ def _strip_injected_context(message: dict) -> dict: class StreamChatRequest(BaseModel): """Request model for streaming chat with optional context.""" - message: str + message: str = Field(max_length=64_000) is_user_message: bool = True context: dict[str, str] | None = None # {url: str, content: str} file_ids: list[str] | None = Field( @@ -165,15 +168,31 @@ class PeekPendingMessagesResponse(BaseModel): class CreateSessionRequest(BaseModel): - """Request model for creating a new chat session. + """Request model for creating (or get-or-creating) a chat session. + + Two modes, selected by the body: + + - Default: create a fresh session. ``dry_run`` is a **top-level** + field — do not nest it inside ``metadata``. + - Builder-bound: when ``builder_graph_id`` is set, the endpoint + switches to **get-or-create** keyed on + ``(user_id, builder_graph_id)``. The builder panel calls this on + mount so the chat persists across refreshes. Graph ownership is + validated inside :func:`get_or_create_builder_session`. Write-side + scope is enforced per-tool (``edit_agent`` / ``run_agent`` reject + any ``agent_id`` other than the bound graph) and a small blacklist + hides tools that conflict with the panel's scope + (``create_agent`` / ``customize_agent`` / ``get_agent_building_guide`` + — see :data:`BUILDER_BLOCKED_TOOLS`). Read-side lookups + (``find_block``, ``find_agent``, ``search_docs``, …) stay open. 
- ``dry_run`` is a **top-level** field — do not nest it inside ``metadata``. Extra/unknown fields are rejected (422) to prevent silent mis-use. """ model_config = ConfigDict(extra="forbid") dry_run: bool = False + builder_graph_id: str | None = Field(default=None, max_length=128) class CreateSessionResponse(BaseModel): @@ -318,29 +337,43 @@ async def create_session( user_id: Annotated[str, Security(auth.get_user_id)], request: CreateSessionRequest | None = None, ) -> CreateSessionResponse: - """ - Create a new chat session. + """Create (or get-or-create) a chat session. - Initiates a new chat session for the authenticated user. + Two modes, selected by the request body: + + - Default: create a fresh session for the user. ``dry_run=True`` forces + run_block and run_agent calls to use dry-run simulation. + - Builder-bound: when ``builder_graph_id`` is set, get-or-create keyed + on ``(user_id, builder_graph_id)``. Returns the existing session for + that graph or creates one locked to it. Graph ownership is validated + inside :func:`get_or_create_builder_session`; raises 404 on + unauthorized access. Write-side scope is enforced per-tool + (``edit_agent`` / ``run_agent`` reject any ``agent_id`` other than + the bound graph) and a small blacklist hides tools that conflict + with the panel's scope (see :data:`BUILDER_BLOCKED_TOOLS`). Args: user_id: The authenticated user ID parsed from the JWT (required). - request: Optional request body. When provided, ``dry_run=True`` - forces run_block and run_agent calls to use dry-run simulation. + request: Optional request body with ``dry_run`` and/or + ``builder_graph_id``. Returns: - CreateSessionResponse: Details of the created session. - + CreateSessionResponse: Details of the resulting session. """ dry_run = request.dry_run if request else False + builder_graph_id = request.builder_graph_id if request else None logger.info( f"Creating session with user_id: " f"...{user_id[-8:] if len(user_id) > 8 else ''}" f"{', dry_run=True' if dry_run else ''}" + f"{f', builder_graph_id={builder_graph_id}' if builder_graph_id else ''}" ) - session = await create_chat_session(user_id, dry_run=dry_run) + if builder_graph_id: + session = await get_or_create_builder_session(user_id, builder_graph_id) + else: + session = await create_chat_session(user_id, dry_run=dry_run) return CreateSessionResponse( id=session.session_id, @@ -536,23 +569,27 @@ async def get_session( ) async def get_copilot_usage( user_id: Annotated[str, Security(auth.get_user_id)], -) -> CoPilotUsageStatus: +) -> CoPilotUsagePublic: """Get CoPilot usage status for the authenticated user. - Returns current token usage vs limits for daily and weekly windows. - Global defaults sourced from LaunchDarkly (falling back to config). - Includes the user's rate-limit tier. + Returns the percentage of the daily/weekly allowance used — not the + raw spend or cap — so clients cannot derive per-turn cost or platform + margins. Global defaults sourced from LaunchDarkly (falling back to + config). Includes the user's rate-limit tier. 
""" daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) - return await get_usage_status( + status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, tier=tier, ) + return CoPilotUsagePublic.from_status(status) class RateLimitResetResponse(BaseModel): @@ -561,7 +598,9 @@ class RateLimitResetResponse(BaseModel): success: bool credits_charged: int = Field(description="Credits charged (in cents)") remaining_balance: int = Field(description="Credit balance after charge (in cents)") - usage: CoPilotUsageStatus = Field(description="Updated usage status after reset") + usage: CoPilotUsagePublic = Field( + description="Updated usage status after reset (percentages only)" + ) @router.post( @@ -585,7 +624,7 @@ async def reset_copilot_usage( ) -> RateLimitResetResponse: """Reset the daily CoPilot rate limit by spending credits. - Allows users who have hit their daily token limit to spend credits + Allows users who have hit their daily cost limit to spend credits to reset their daily usage counter and continue working. Returns 400 if the feature is disabled or the user is not over the limit. Returns 402 if the user has insufficient credits. @@ -604,7 +643,9 @@ async def reset_copilot_usage( ) daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) if daily_limit <= 0: @@ -641,8 +682,8 @@ async def reset_copilot_usage( # used for limit checks, not returned to the client.) usage_status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, tier=tier, ) if daily_limit > 0 and usage_status.daily.used < daily_limit: @@ -677,7 +718,7 @@ async def reset_copilot_usage( # Reset daily usage in Redis. If this fails, refund the credits # so the user is not charged for a service they did not receive. - if not await reset_daily_usage(user_id, daily_token_limit=daily_limit): + if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit): # Compensate: refund the charged credits. refunded = False try: @@ -713,11 +754,11 @@ async def reset_copilot_usage( finally: await release_reset_lock(user_id) - # Return updated usage status. + # Return updated usage status (public schema — percentages only). 
updated_usage = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, tier=tier, ) @@ -726,7 +767,7 @@ async def reset_copilot_usage( success=True, credits_charged=cost, remaining_balance=remaining, - usage=updated_usage, + usage=CoPilotUsagePublic.from_status(updated_usage), ) @@ -787,7 +828,7 @@ async def cancel_session_task( ), }, 404: {"description": "Session not found or access denied"}, - 429: {"description": "Token rate-limit or call-frequency cap exceeded"}, + 429: {"description": "Cost rate-limit or call-frequency cap exceeded"}, }, ) async def stream_chat_post( @@ -830,7 +871,8 @@ async def stream_chat_post( f"user={user_id}, message_len={len(request.message)}", extra={"json_fields": log_meta}, ) - await _validate_and_get_session(session_id, user_id) + session = await _validate_and_get_session(session_id, user_id) + builder_permissions = resolve_session_permissions(session) # Self-defensive queue-fallback: if a turn is already running, don't race # it on the cluster lock — drop the message into the pending buffer and @@ -861,18 +903,20 @@ async def stream_chat_post( }, ) - # Pre-turn rate limit check (token-based). + # Pre-turn rate limit check (cost-based, microdollars). # check_rate_limit short-circuits internally when both limits are 0. # Global defaults sourced from LaunchDarkly, falling back to config. if user_id: try: daily_limit, weekly_limit, _ = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) await check_rate_limit( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, ) except RateLimitExceeded as e: raise HTTPException(status_code=429, detail=str(e)) from e @@ -943,6 +987,7 @@ async def stream_chat_post( file_ids=sanitized_file_ids, mode=request.mode, model=request.model, + permissions=builder_permissions, request_arrival_at=request_arrival_at, ) else: @@ -1375,6 +1420,7 @@ ToolResponseUnion = ( | MemorySearchResponse | MemoryForgetCandidatesResponse | MemoryForgetConfirmResponse + | TodoWriteResponse ) diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index 4dc6547515..1f692ab299 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -11,10 +11,20 @@ import pytest_mock from backend.api.features.chat import routes as chat_routes from backend.api.features.chat.routes import _strip_injected_context from backend.copilot.rate_limit import SubscriptionTier +from backend.util.exceptions import NotFoundError app = fastapi.FastAPI() app.include_router(chat_routes.router) + +@app.exception_handler(NotFoundError) +async def _not_found_handler( + request: fastapi.Request, exc: NotFoundError +) -> fastapi.responses.JSONResponse: + """Mirror the production NotFoundError → 404 mapping from the REST app.""" + return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)}) + + client = fastapi.testclient.TestClient(app) TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a" @@ -296,8 +306,8 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: 
pytest_mock.MockerF _mock_stream_internals(mocker) # Ensure the rate-limit branch is entered by setting a non-zero limit. - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)), @@ -318,8 +328,8 @@ def test_stream_chat_returns_429_on_weekly_rate_limit( from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) resets_at = datetime.now(UTC) + timedelta(days=3) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", @@ -341,8 +351,8 @@ def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture): from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded( @@ -370,7 +380,7 @@ def _mock_usage( weekly_used: int = 2000, daily_limit: int = 10000, weekly_limit: int = 50000, - tier: "SubscriptionTier" = SubscriptionTier.FREE, + tier: "SubscriptionTier" = SubscriptionTier.BASIC, ) -> AsyncMock: """Mock get_usage_status and get_global_rate_limits for usage endpoint tests. @@ -402,25 +412,35 @@ def test_usage_returns_daily_and_weekly( mocker: pytest_mock.MockerFixture, test_user_id: str, ) -> None: - """GET /usage returns daily and weekly usage.""" + """GET /usage returns percentages for daily and weekly windows only. + + The raw used/limit microdollar values MUST NOT leak — clients should not + be able to derive per-turn cost or platform margins from the public API. + """ mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) response = client.get("/usage") assert response.status_code == 200 data = response.json() - assert data["daily"]["used"] == 500 - assert data["weekly"]["used"] == 2000 + # 500 / 10000 = 5%, 2000 / 50000 = 4% + assert data["daily"]["percent_used"] == 5.0 + assert data["weekly"]["percent_used"] == 4.0 + # Raw spend/limit must not be exposed. 
+ assert "used" not in data["daily"] + assert "limit" not in data["daily"] + assert "used" not in data["weekly"] + assert "limit" not in data["weekly"] mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=10000, - weekly_token_limit=50000, + daily_cost_limit=10000, + weekly_cost_limit=50000, rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost, - tier=SubscriptionTier.FREE, + tier=SubscriptionTier.BASIC, ) @@ -438,10 +458,10 @@ def test_usage_uses_config_limits( assert response.status_code == 200 mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=99999, - weekly_token_limit=77777, + daily_cost_limit=99999, + weekly_cost_limit=77777, rate_limit_reset_cost=500, - tier=SubscriptionTier.FREE, + tier=SubscriptionTier.BASIC, ) @@ -954,6 +974,618 @@ class TestStripInjectedContext: assert result["content"] == "hello" +# ─── message max_length validation ─────────────────────────────────── + + +def test_stream_chat_rejects_too_long_message(): + """A message exceeding max_length=64_000 must be rejected (422).""" + response = client.post( + "/sessions/sess-1/stream", + json={ + "message": "x" * 64_001, + }, + ) + assert response.status_code == 422 + + +def test_stream_chat_accepts_exactly_max_length_message( + mocker: pytest_mock.MockFixture, +): + """A message exactly at max_length=64_000 must be accepted.""" + _mock_stream_internals(mocker) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 0, SubscriptionTier.BASIC), + ) + + response = client.post( + "/sessions/sess-1/stream", + json={ + "message": "x" * 64_000, + }, + ) + assert response.status_code == 200 + + +# ─── list_sessions ──────────────────────────────────────────────────── + + +def _make_session_info(session_id: str = "sess-1", title: str | None = "Test"): + """Build a minimal ChatSessionInfo-like mock.""" + from backend.copilot.model import ChatSessionInfo, ChatSessionMetadata + + return ChatSessionInfo( + session_id=session_id, + user_id=TEST_USER_ID, + title=title, + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + metadata=ChatSessionMetadata(), + ) + + +def test_list_sessions_returns_sessions(mocker: pytest_mock.MockerFixture) -> None: + """GET /sessions returns list of sessions with is_processing=False when Redis OK.""" + session = _make_session_info("sess-abc") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + # Redis pipeline returns "done" (not "running") for this session + mock_redis = MagicMock() + mock_pipe = MagicMock() + mock_pipe.hget = MagicMock(return_value=None) + mock_pipe.execute = AsyncMock(return_value=["done"]) + mock_redis.pipeline = MagicMock(return_value=mock_pipe) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert len(data["sessions"]) == 1 + assert data["sessions"][0]["id"] == "sess-abc" + assert data["sessions"][0]["is_processing"] is False + + +def test_list_sessions_marks_running_as_processing( + mocker: pytest_mock.MockerFixture, +) -> None: + """Sessions with Redis status='running' should have is_processing=True.""" + session = _make_session_info("sess-xyz") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + 
return_value=([session], 1), + ) + mock_redis = MagicMock() + mock_pipe = MagicMock() + mock_pipe.hget = MagicMock(return_value=None) + mock_pipe.execute = AsyncMock(return_value=["running"]) + mock_redis.pipeline = MagicMock(return_value=mock_pipe) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + assert response.json()["sessions"][0]["is_processing"] is True + + +def test_list_sessions_redis_failure_defaults_to_not_processing( + mocker: pytest_mock.MockerFixture, +) -> None: + """Redis failures must be swallowed and sessions default to is_processing=False.""" + session = _make_session_info("sess-fallback") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + side_effect=Exception("Redis down"), + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + assert response.json()["sessions"][0]["is_processing"] is False + + +def test_list_sessions_empty(mocker: pytest_mock.MockerFixture) -> None: + """GET /sessions with no sessions returns empty list without hitting Redis.""" + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([], 0), + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 0 + assert data["sessions"] == [] + + +# ─── delete_session ─────────────────────────────────────────────────── + + +def test_delete_session_success(mocker: pytest_mock.MockerFixture) -> None: + """DELETE /sessions/{id} returns 204 when deleted successfully.""" + mocker.patch( + "backend.api.features.chat.routes.delete_chat_session", + new_callable=AsyncMock, + return_value=True, + ) + # Patch use_e2b_sandbox env-var to disable E2B so the route skips sandbox cleanup. + # Patching the Pydantic property directly doesn't work (Pydantic v2 intercepts + # attribute setting on BaseSettings instances and raises AttributeError). 
+ mocker.patch.dict("os.environ", {"USE_E2B_SANDBOX": "false"}) + + response = client.delete("/sessions/sess-1") + + assert response.status_code == 204 + + +def test_delete_session_not_found(mocker: pytest_mock.MockerFixture) -> None: + """DELETE /sessions/{id} returns 404 when session not found or not owned.""" + mocker.patch( + "backend.api.features.chat.routes.delete_chat_session", + new_callable=AsyncMock, + return_value=False, + ) + + response = client.delete("/sessions/sess-missing") + + assert response.status_code == 404 + + +# ─── cancel_session_task ────────────────────────────────────────────── + + +def _mock_validate_session( + mocker: pytest_mock.MockerFixture, *, session_id: str = "sess-1" +): + """Mock _validate_and_get_session to return a dummy session.""" + from backend.copilot.model import ChatSession + + dummy = ChatSession.new(TEST_USER_ID, dry_run=False) + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + new_callable=AsyncMock, + return_value=dummy, + ) + + +def test_cancel_session_no_active_task(mocker: pytest_mock.MockerFixture) -> None: + """Cancel returns cancelled=True with reason when no stream is active.""" + _mock_validate_session(mocker) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(None, None)) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.post("/sessions/sess-1/cancel") + + assert response.status_code == 200 + data = response.json() + assert data["cancelled"] is True + assert data["reason"] == "no_active_session" + + +def test_cancel_session_enqueues_cancel_and_confirms( + mocker: pytest_mock.MockerFixture, +) -> None: + """Cancel enqueues cancel task and returns cancelled=True once stream stops.""" + from backend.copilot.stream_registry import ActiveSession + + _mock_validate_session(mocker) + active_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="running", + ) + stopped_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="completed", + ) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0")) + mock_registry.get_session = AsyncMock(return_value=stopped_session) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + mock_enqueue = mocker.patch( + "backend.api.features.chat.routes.enqueue_cancel_task", + new_callable=AsyncMock, + ) + + response = client.post("/sessions/sess-1/cancel") + + assert response.status_code == 200 + assert response.json()["cancelled"] is True + mock_enqueue.assert_called_once_with("sess-1") + + +# ─── session_assign_user ────────────────────────────────────────────── + + +def test_session_assign_user(mocker: pytest_mock.MockerFixture) -> None: + """PATCH /sessions/{id}/assign-user calls assign_user_to_session and returns ok.""" + mock_assign = mocker.patch( + "backend.api.features.chat.routes.chat_service.assign_user_to_session", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.patch("/sessions/sess-1/assign-user") + + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + mock_assign.assert_called_once_with("sess-1", TEST_USER_ID) + + +# ─── get_ttl_config ────────────────────────────────────────────────── + + +def test_get_ttl_config(mocker: 
pytest_mock.MockerFixture) -> None: + """GET /config/ttl returns correct TTL values derived from config.""" + mocker.patch.object(chat_routes.config, "stream_ttl", 300) + + response = client.get("/config/ttl") + + assert response.status_code == 200 + data = response.json() + assert data["stream_ttl_seconds"] == 300 + assert data["stream_ttl_ms"] == 300_000 + + +# ─── reset_copilot_usage ────────────────────────────────────────────── + + +def _mock_reset_internals( + mocker: pytest_mock.MockerFixture, + *, + cost: int = 100, + enable_credit: bool = True, + daily_limit: int = 10_000, + weekly_limit: int = 50_000, + tier: "SubscriptionTier" = SubscriptionTier.BASIC, + daily_used: int = 10_001, + weekly_used: int = 1_000, + reset_count: int | None = 0, + acquire_lock: bool = True, + reset_daily: bool = True, + remaining_balance: int = 9_000, +): + """Set up all dependencies for reset_copilot_usage tests.""" + from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow + + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", cost) + mocker.patch.object(chat_routes.config, "max_daily_resets", 3) + mocker.patch.object(chat_routes.settings.config, "enable_credit", enable_credit) + + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(daily_limit, weekly_limit, tier), + ) + resets_at = datetime.now(UTC) + timedelta(hours=1) + status = CoPilotUsageStatus( + daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at), + weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at), + ) + mocker.patch( + "backend.api.features.chat.routes.get_usage_status", + new_callable=AsyncMock, + return_value=status, + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=reset_count, + ) + mocker.patch( + "backend.api.features.chat.routes.acquire_reset_lock", + new_callable=AsyncMock, + return_value=acquire_lock, + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + mocker.patch( + "backend.api.features.chat.routes.reset_daily_usage", + new_callable=AsyncMock, + return_value=reset_daily, + ) + mocker.patch( + "backend.api.features.chat.routes.increment_daily_reset_count", + new_callable=AsyncMock, + ) + + mock_credit_model = MagicMock() + mock_credit_model.spend_credits = AsyncMock(return_value=remaining_balance) + mock_credit_model.top_up_credits = AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.get_user_credit_model", + new_callable=AsyncMock, + return_value=mock_credit_model, + ) + return mock_credit_model + + +def test_reset_usage_returns_400_when_cost_is_zero( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when rate_limit_reset_cost <= 0.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 0) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "not available" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_credits_disabled( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when credit system is disabled.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", False) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "disabled" in 
response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_no_daily_limit( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when daily_limit is 0.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 50_000, SubscriptionTier.BASIC), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=0, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "nothing to reset" in response.json()["detail"].lower() + + +def test_reset_usage_returns_503_when_redis_unavailable( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 503 when Redis is unavailable for reset count.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.BASIC), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 503 + + +def test_reset_usage_returns_429_when_max_resets_reached( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 429 when max daily resets exceeded.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.config, "max_daily_resets", 2) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.BASIC), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=2, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 429 + assert "resets" in response.json()["detail"].lower() + + +def test_reset_usage_returns_429_when_lock_not_acquired( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 429 when a concurrent reset is in progress.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.config, "max_daily_resets", 3) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.BASIC), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.chat.routes.acquire_reset_lock", + new_callable=AsyncMock, + return_value=False, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 429 + assert "in progress" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_limit_not_reached( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when daily limit has not been reached.""" + _mock_reset_internals(mocker, daily_used=500, daily_limit=10_000) + mocker.patch( + 
"backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "not reached" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_weekly_also_exhausted( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when weekly limit is also exhausted.""" + _mock_reset_internals( + mocker, + daily_used=10_001, + daily_limit=10_000, + weekly_used=50_001, + weekly_limit=50_000, + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "weekly" in response.json()["detail"].lower() + + +def test_reset_usage_returns_402_when_insufficient_credits( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 402 when credits are insufficient.""" + from backend.util.exceptions import InsufficientBalanceError + + mock_credit = _mock_reset_internals(mocker) + mock_credit.spend_credits = AsyncMock( + side_effect=InsufficientBalanceError( + message="Insufficient balance", + user_id=TEST_USER_ID, + balance=0.0, + amount=100.0, + ) + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 402 + + +def test_reset_usage_success(mocker: pytest_mock.MockerFixture) -> None: + """POST /usage/reset returns 200 with updated usage on success.""" + _mock_reset_internals(mocker, remaining_balance=8_900) + + response = client.post("/usage/reset") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["credits_charged"] == 100 + assert data["remaining_balance"] == 8_900 + assert "daily" in data["usage"] + assert "weekly" in data["usage"] + + +def test_reset_usage_refunds_on_redis_failure( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 503 and refunds credits when Redis reset fails.""" + mock_credit = _mock_reset_internals(mocker, reset_daily=False) + + response = client.post("/usage/reset") + + assert response.status_code == 503 + # Credits should be refunded via top_up_credits + mock_credit.top_up_credits.assert_called_once() + + +# ─── resume_session_stream ─────────────────────────────────────────── + + +def test_resume_session_stream_no_active_session( + mocker: pytest_mock.MockerFixture, +) -> None: + """GET /sessions/{id}/stream returns 204 when no active session.""" + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(None, None)) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.get("/sessions/sess-1/stream") + + assert response.status_code == 204 + + +def test_resume_session_stream_no_subscriber_queue( + mocker: pytest_mock.MockerFixture, +) -> None: + """GET /sessions/{id}/stream returns 204 when subscribe_to_session returns None.""" + from backend.copilot.stream_registry import ActiveSession + + active_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="running", + ) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0")) + mock_registry.subscribe_to_session = AsyncMock(return_value=None) + 
mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.get("/sessions/sess-1/stream") + + assert response.status_code == 204 + + # ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── @@ -1053,3 +1685,119 @@ def test_get_session_returns_backward_paginated( assert data["oldest_sequence"] == 0 assert "forward_paginated" not in data assert "newest_sequence" not in data + + +# ─── POST /sessions with builder_graph_id (get-or-create) ────────────── + + +def test_create_session_with_builder_graph_id_uses_get_or_create( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """``POST /sessions`` with ``builder_graph_id`` routes through + ``get_or_create_builder_session`` and returns a session bound to the graph.""" + from backend.copilot.model import ChatSession + + async def _fake_get_or_create(user_id: str, graph_id: str) -> ChatSession: + return ChatSession.new( + user_id, + dry_run=False, + builder_graph_id=graph_id, + ) + + mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + side_effect=_fake_get_or_create, + ) + + response = client.post("/sessions", json={"builder_graph_id": "graph-1"}) + + assert response.status_code == 200 + body = response.json() + assert body["metadata"]["builder_graph_id"] == "graph-1" + assert body["metadata"]["dry_run"] is False + + +def test_create_session_with_builder_graph_id_returns_404_when_not_owned( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """``get_or_create_builder_session`` raises ``NotFoundError`` when the + user doesn't own the graph; the route must map that to HTTP 404.""" + + async def _fake_get_or_create(user_id: str, graph_id: str): + raise NotFoundError(f"Graph {graph_id} not found") + + mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + side_effect=_fake_get_or_create, + ) + + response = client.post("/sessions", json={"builder_graph_id": "graph-unauthorized"}) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + +def test_create_session_without_builder_graph_id_creates_fresh( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """With no ``builder_graph_id`` the endpoint falls through to the + default ``create_chat_session`` path — no get-or-create lookup.""" + from backend.copilot.model import ChatSession + + gorc = mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + ) + + async def _fake_create(user_id: str, *, dry_run: bool) -> ChatSession: + return ChatSession.new(user_id, dry_run=dry_run) + + mocker.patch( + "backend.api.features.chat.routes.create_chat_session", + new_callable=AsyncMock, + side_effect=_fake_create, + ) + + response = client.post("/sessions", json={"dry_run": True}) + + assert response.status_code == 200 + assert response.json()["metadata"]["dry_run"] is True + gorc.assert_not_called() + + +def test_create_session_rejects_unknown_fields( + test_user_id: str, +) -> None: + """Extra request fields are rejected (422) to prevent silent mis-use.""" + response = client.post("/sessions", json={"unexpected": "x"}) + assert response.status_code == 422 + + +def test_resolve_session_permissions_blocks_out_of_scope_tools() -> None: + """Builder-bound sessions return a blacklist of the three tools that + conflict with the panel's graph-bound scope. 
Regular sessions return + ``None`` so default (unrestricted) behaviour is preserved.""" + from backend.copilot.builder_context import BUILDER_BLOCKED_TOOLS + from backend.copilot.model import ChatSession + + unbound = ChatSession.new("u1", dry_run=False) + assert chat_routes.resolve_session_permissions(unbound) is None + + bound = ChatSession.new("u1", dry_run=False, builder_graph_id="g1") + perms = chat_routes.resolve_session_permissions(bound) + assert perms is not None + assert perms.tools_exclude is True # blacklist, not whitelist + assert sorted(perms.tools) == sorted(BUILDER_BLOCKED_TOOLS) + # Read-side lookups stay available — only write-scope / guide-dup are blocked. + assert "find_block" not in perms.tools + assert "find_agent" not in perms.tools + assert "search_docs" not in perms.tools + # The write tools (edit_agent / run_agent) are NOT blacklisted — they + # enforce scope per-tool via the builder_graph_id guard. + assert "edit_agent" not in perms.tools + assert "run_agent" not in perms.tools diff --git a/autogpt_platform/backend/backend/api/features/integrations/incremental_oauth_test.py b/autogpt_platform/backend/backend/api/features/integrations/incremental_oauth_test.py new file mode 100644 index 0000000000..352f14df5a --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/integrations/incremental_oauth_test.py @@ -0,0 +1,1130 @@ +"""Tests for incremental OAuth authorization (scope upgrade).""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import fastapi +import fastapi.testclient +import pytest +from pydantic import SecretStr + +from backend.api.features.integrations.router import router +from backend.data.model import APIKeyCredentials, OAuth2Credentials, OAuthState + +app = fastapi.FastAPI() +app.include_router(router) +client = fastapi.testclient.TestClient(app) + +TEST_USER_ID = "test-user-id" + + +def _make_google_oauth2_cred( + cred_id: str = "google-cred-1", + scopes: list[str] | None = None, + username: str = "alice@gmail.com", + title: str = "My Google", +) -> OAuth2Credentials: + return OAuth2Credentials( + id=cred_id, + provider="google", + title=title, + access_token=SecretStr("ya29.access-token"), + refresh_token=SecretStr("1//refresh-token"), + scopes=( + scopes + if scopes is not None + else ["https://www.googleapis.com/auth/gmail.readonly"] + ), + username=username, + access_token_expires_at=9999999999, + ) + + +def _make_github_oauth2_cred( + cred_id: str = "github-cred-1", + scopes: list[str] | None = None, + username: str = "alice", + title: str = "My GitHub", +) -> OAuth2Credentials: + return OAuth2Credentials( + id=cred_id, + provider="github", + title=title, + access_token=SecretStr("ghp_access_token"), + refresh_token=SecretStr("ghp_refresh_token"), + scopes=scopes if scopes is not None else ["repo"], + username=username, + ) + + +@pytest.fixture(autouse=True) +def setup_auth(mock_jwt_user): + from autogpt_libs.auth.jwt_utils import get_jwt_payload + + app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"] + yield + app.dependency_overrides.clear() + + +# ==================== OAuthState model tests ==================== # + + +class TestOAuthStateCredentialId: + """OAuthState model should support a credential_id field for upgrades.""" + + def test_oauth_state_accepts_credential_id(self): + state = OAuthState( + token="abc", + provider="google", + expires_at=9999999999, + scopes=["openid"], + credential_id="existing-cred-id", + ) + assert state.credential_id == "existing-cred-id" + + def 
test_oauth_state_defaults_credential_id_none(self): + state = OAuthState( + token="abc", + provider="google", + expires_at=9999999999, + scopes=["openid"], + ) + assert state.credential_id is None + + +# ==================== Login endpoint tests ==================== # + + +class TestIncrementalOAuthLogin: + """Tests for the login endpoint with credential_id parameter.""" + + def test_login_with_credential_id_stores_in_state(self): + """Login with credential_id should pass it through to store_state_token.""" + existing = _make_google_oauth2_cred() + handler = MagicMock() + handler.get_login_url.return_value = "https://accounts.google.com/auth" + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.store.store_state_token = AsyncMock( + return_value=("state-token", "code-challenge") + ) + + resp = client.get( + "/google/login", + params={ + "scopes": "https://www.googleapis.com/auth/calendar.readonly", + "credential_id": "google-cred-1", + }, + ) + + assert resp.status_code == 200 + # Verify store_state_token was called with credential_id + call_kwargs = mock_mgr.store.store_state_token.call_args + assert call_kwargs.kwargs.get("credential_id") == "google-cred-1" or ( + len(call_kwargs.args) > 3 and call_kwargs.args[3] == "google-cred-1" + ) + + def test_login_github_unions_scopes_for_upgrade(self): + """For GitHub, login should request union of existing + new scopes.""" + existing = _make_github_oauth2_cred(scopes=["repo"]) + handler = MagicMock() + handler.get_login_url.return_value = "https://github.com/login/oauth/authorize" + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.store.store_state_token = AsyncMock( + return_value=("state-token", "code-challenge") + ) + + resp = client.get( + "/github/login", + params={ + "scopes": "read:org", + "credential_id": "github-cred-1", + }, + ) + + assert resp.status_code == 200 + # The scopes passed to get_login_url should be the union + login_scopes = handler.get_login_url.call_args[0][0] + assert set(login_scopes) == {"repo", "read:org"} + + def test_login_google_keeps_requested_scopes_only(self): + """For Google, login should use only the new scopes (include_granted_scopes handles merging).""" + existing = _make_google_oauth2_cred( + scopes=["https://www.googleapis.com/auth/gmail.readonly"] + ) + handler = MagicMock() + handler.get_login_url.return_value = "https://accounts.google.com/auth" + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.store.store_state_token = AsyncMock( + return_value=("state-token", "code-challenge") + ) + + resp = client.get( + "/google/login", + params={ + "scopes": "https://www.googleapis.com/auth/calendar.readonly", + "credential_id": "google-cred-1", + }, + ) + + assert resp.status_code == 200 + login_scopes = handler.get_login_url.call_args[0][0] + # Google should NOT union scopes in the login URL + assert 
"https://www.googleapis.com/auth/calendar.readonly" in login_scopes + assert "https://www.googleapis.com/auth/gmail.readonly" not in login_scopes + # Verify credential_id was passed through to store_state_token + call_kwargs = mock_mgr.store.store_state_token.call_args + assert call_kwargs.kwargs.get("credential_id") == "google-cred-1" + + def test_login_credential_not_found_returns_404(self): + handler = MagicMock() + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=None) + + resp = client.get( + "/google/login", + params={ + "scopes": "openid", + "credential_id": "nonexistent", + }, + ) + + assert resp.status_code == 404 + + def test_login_credential_provider_mismatch_returns_400(self): + """credential_id pointing to a Google cred when URL says github -> 400.""" + google_cred = _make_google_oauth2_cred() + handler = MagicMock() + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=google_cred) + + resp = client.get( + "/github/login", + params={ + "scopes": "repo", + "credential_id": "google-cred-1", + }, + ) + + assert resp.status_code == 400 + + def test_login_non_oauth2_credential_returns_400(self): + """credential_id pointing to an API key credential -> 400.""" + api_key_cred = APIKeyCredentials( + id="apikey-1", + provider="github", + title="API Key", + api_key=SecretStr("ghp_key"), + ) + handler = MagicMock() + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=api_key_cred) + + resp = client.get( + "/github/login", + params={ + "scopes": "repo", + "credential_id": "apikey-1", + }, + ) + + assert resp.status_code == 400 + + +# ==================== Callback endpoint tests ==================== # + + +class TestIncrementalOAuthCallback: + """Tests for the callback endpoint when upgrading credentials.""" + + def _make_state_with_credential_id( + self, + credential_id: str, + scopes: list[str] | None = None, + provider: str = "google", + ) -> OAuthState: + return OAuthState( + token="state-token", + provider=provider, + expires_at=9999999999, + scopes=( + scopes + if scopes is not None + else ["https://www.googleapis.com/auth/calendar.readonly"] + ), + credential_id=credential_id, + ) + + def test_callback_upgrades_existing_credential(self): + """When state has credential_id, should update existing credential.""" + existing = _make_google_oauth2_cred( + scopes=["https://www.googleapis.com/auth/gmail.readonly"] + ) + new_cred = _make_google_oauth2_cred( + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ] + ) + state = self._make_state_with_credential_id("google-cred-1") + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + 
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + mock_mgr.create = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + # Should call update, not create + mock_mgr.update.assert_called_once() + mock_mgr.create.assert_not_called() + + def test_callback_upgrade_merges_scopes(self): + """Upgraded credential should have union of old + new scopes.""" + existing = _make_google_oauth2_cred( + scopes=["https://www.googleapis.com/auth/gmail.readonly"] + ) + new_cred = _make_google_oauth2_cred( + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ] + ) + state = self._make_state_with_credential_id("google-cred-1") + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert set(data["scopes"]) == { + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + } + + def test_callback_upgrade_preserves_id_and_title(self): + """Upgraded credential should keep its original ID and title.""" + existing = _make_google_oauth2_cred( + cred_id="original-id", title="My Work Google" + ) + new_cred = _make_google_oauth2_cred(cred_id="new-id-from-exchange") + state = self._make_state_with_credential_id("original-id") + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == "original-id" + assert data["title"] == "My Work Google" + + def test_callback_upgrade_rejects_username_mismatch(self): + """Should reject if the new auth returns a different username.""" + existing = _make_google_oauth2_cred(username="alice@gmail.com") + new_cred = _make_google_oauth2_cred(username="bob@gmail.com") + state = self._make_state_with_credential_id("google-cred-1") + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + 
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 400 + assert "username" in resp.json()["detail"].lower() + + def test_callback_implicit_merge_same_provider_username(self): + """Without credential_id, should auto-merge when same provider+username exists.""" + existing = _make_google_oauth2_cred( + scopes=["https://www.googleapis.com/auth/gmail.readonly"] + ) + new_cred = _make_google_oauth2_cred( + cred_id="new-cred-id", + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ], + username="alice@gmail.com", + ) + # State WITHOUT credential_id + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/calendar.readonly"], + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_provider = AsyncMock(return_value=[existing]) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + mock_mgr.create = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + # Should update the existing credential, not create a new one + mock_mgr.update.assert_called_once() + mock_mgr.create.assert_not_called() + # The returned ID should be the existing credential's ID + data = resp.json() + assert data["id"] == "google-cred-1" + + def test_callback_no_implicit_merge_different_username(self): + """Without credential_id, different username should create new credential.""" + existing = _make_google_oauth2_cred(username="alice@gmail.com") + new_cred = _make_google_oauth2_cred( + cred_id="new-cred-id", + username="bob@gmail.com", + ) + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_provider = AsyncMock(return_value=[existing]) + mock_mgr.create = AsyncMock() + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + mock_mgr.create.assert_called_once() + mock_mgr.update.assert_not_called() + # Verify the implicit merge lookup was attempted + mock_mgr.store.get_creds_by_provider.assert_called_once() + + def test_callback_creates_new_when_no_existing(self): + 
"""Without credential_id and no matching credential, creates new.""" + new_cred = _make_google_oauth2_cred() + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_provider = AsyncMock(return_value=[]) + mock_mgr.create = AsyncMock() + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + mock_mgr.create.assert_called_once() + mock_mgr.update.assert_not_called() + # Verify the implicit merge lookup was attempted + mock_mgr.store.get_creds_by_provider.assert_called_once() + + +# ==================== Round 2: Review feedback tests ==================== # + + +class TestManagedCredentialProtection: + """Managed/system credentials must not be upgradeable.""" + + def test_login_rejects_managed_credential_id(self): + """Explicit credential_id pointing to a managed credential -> 400.""" + managed = _make_google_oauth2_cred(cred_id="managed-1") + managed.is_managed = True + handler = MagicMock() + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=managed) + + resp = client.get( + "/google/login", + params={ + "scopes": "https://www.googleapis.com/auth/calendar.readonly", + "credential_id": "managed-1", + }, + ) + + assert resp.status_code == 400 + + def test_callback_rejects_upgrade_of_managed_credential(self): + """Callback with credential_id for a managed credential -> 400.""" + managed = _make_google_oauth2_cred(cred_id="managed-1") + managed.is_managed = True + new_cred = _make_google_oauth2_cred() + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/calendar.readonly"], + credential_id="managed-1", + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=managed) + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 400 + + +class TestMetadataNoneGuard: + """Metadata merge must handle None values.""" + + def test_callback_upgrade_handles_none_metadata(self): + """Upgrading credential with metadata=None should not crash.""" + existing = _make_google_oauth2_cred( + scopes=["https://www.googleapis.com/auth/gmail.readonly"] + ) + existing.metadata = None # type: ignore[assignment] + new_cred = _make_google_oauth2_cred( + 
scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ] + ) + new_cred.metadata = None # type: ignore[assignment] + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/calendar.readonly"], + credential_id="google-cred-1", + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + + +class TestStateHelperScopesPattern: + """Test helper should handle empty scopes correctly.""" + + def test_make_state_preserves_empty_scopes(self): + """_make_state_with_credential_id([]) should keep empty list.""" + state_maker = TestIncrementalOAuthCallback() + state = state_maker._make_state_with_credential_id("cred-1", scopes=[]) + assert state.scopes == [] + + +class TestSystemCredentialProtection: + """Platform-owned system credentials must never be upgraded.""" + + def test_login_rejects_system_credential_id(self): + """Explicit credential_id pointing to a system credential -> 400.""" + handler = MagicMock() + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch( + "backend.api.features.integrations.router.is_system_credential", + return_value=True, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock() + + resp = client.get( + "/google/login", + params={ + "scopes": "https://www.googleapis.com/auth/calendar.readonly", + "credential_id": "system-cred-id", + }, + ) + + assert resp.status_code == 400 + assert "system credentials" in resp.json()["detail"].lower() + # The store lookup must never happen for system credentials. + mock_mgr.store.get_creds_by_id.assert_not_called() + + def test_callback_rejects_upgrade_of_system_credential(self): + """Defense-in-depth: even if a stale login state points at a system + credential, the callback-time `_upgrade_existing_credential` must + reject it before persisting anything.""" + existing = _make_google_oauth2_cred(cred_id="sys-cred-id") + new_cred = _make_google_oauth2_cred( + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ] + ) + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/calendar.readonly"], + credential_id="sys-cred-id", + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + # is_system_credential returns True only when asked about "sys-cred-id" + # — emulating the real predicate that recognises platform-reserved IDs. 
+ def _is_system(cred_id): + return cred_id == "sys-cred-id" + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch( + "backend.api.features.integrations.router.is_system_credential", + side_effect=_is_system, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 400 + assert "system credentials" in resp.json()["detail"].lower() + # No write must have happened for the system credential. + mock_mgr.update.assert_not_called() + + def test_implicit_merge_skips_system_credentials(self): + """The implicit (provider+username) merge filter must exclude system + credentials so a user login cannot accidentally overwrite one.""" + system_match = _make_google_oauth2_cred( + cred_id="sys-cred-id", username="alice@gmail.com" + ) + new_cred = _make_google_oauth2_cred( + cred_id="new-cred-id", + scopes=system_match.scopes, + username="alice@gmail.com", + ) + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=system_match.scopes, + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + def _is_system(cred_id): + return cred_id == "sys-cred-id" + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch( + "backend.api.features.integrations.router.is_system_credential", + side_effect=_is_system, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_provider = AsyncMock( + return_value=[system_match] + ) + mock_mgr.create = AsyncMock() + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + # Since the only provider+username match is a system credential, the + # callback must create a new credential rather than overwriting it. + mock_mgr.create.assert_called_once() + mock_mgr.update.assert_not_called() + + def test_upgrade_rejects_provider_mismatch(self): + """Defense-in-depth: if a stale login somehow passed validation but the + stored credential's provider no longer matches the new token's + provider, the write-path must refuse to overwrite it.""" + existing = _make_google_oauth2_cred(cred_id="mixed-up-cred") + # Simulate a provider drift: the new credential exchange returned a + # different provider than what's stored on disk. 
+ new_cred = _make_github_oauth2_cred(cred_id="mixed-up-cred") + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + credential_id="mixed-up-cred", + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 400 + assert "provider" in resp.json()["detail"].lower() + mock_mgr.update.assert_not_called() + + +class TestPreserveRefreshTokenAndUsername: + """Incremental callbacks must not silently drop refresh_token/username.""" + + def test_upgrade_preserves_existing_refresh_token_when_new_is_empty(self): + """If the new token response omits refresh_token, keep the existing one.""" + existing = _make_google_oauth2_cred( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) + existing.refresh_token = SecretStr("original-refresh") + # Google may omit refresh_token on incremental re-authorization. + new_cred = _make_google_oauth2_cred( + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ], + ) + new_cred.refresh_token = None # type: ignore[assignment] + + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=["https://www.googleapis.com/auth/calendar.readonly"], + credential_id="google-cred-1", + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + captured: dict[str, OAuth2Credentials] = {} + + async def _capture_update(_user_id, creds): + captured["creds"] = creds + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock(side_effect=_capture_update) + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + updated = captured["creds"] + assert updated.refresh_token is not None + assert updated.refresh_token.get_secret_value() == "original-refresh" + + def test_upgrade_preserves_existing_username_when_new_is_empty(self): + """If the new response lacks username, keep the existing one.""" + existing = _make_google_oauth2_cred(username="alice@gmail.com") + new_cred = _make_google_oauth2_cred(scopes=existing.scopes) + new_cred.username = None + + state = OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=existing.scopes, + credential_id="google-cred-1", + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + captured: dict[str, 
OAuth2Credentials] = {} + + async def _capture_update(_user_id, creds): + captured["creds"] = creds + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock(side_effect=_capture_update) + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + assert captured["creds"].username == "alice@gmail.com" + + +class TestImplicitMergeScopeGuard: + """Implicit (provider+username) merge must not advertise scopes wider than + the freshly-minted token actually grants.""" + + def _build_state(self, scopes: list[str]) -> OAuthState: + return OAuthState( + token="state-token", + provider="google", + expires_at=9999999999, + scopes=scopes, + ) + + def test_implicit_merge_skipped_when_new_scopes_narrower(self): + """If the new token doesn't cover all existing scopes, create a + fresh credential instead of overwriting the existing one.""" + existing = _make_google_oauth2_cred( + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ], + ) + # New login only requested gmail — narrower than existing. + new_cred = _make_google_oauth2_cred( + cred_id="new-cred-id", + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) + state = self._build_state(["https://www.googleapis.com/auth/gmail.readonly"]) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_provider = AsyncMock(return_value=[existing]) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.create = AsyncMock() + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + mock_mgr.create.assert_called_once() + mock_mgr.update.assert_not_called() + + def test_implicit_merge_allowed_when_new_scopes_are_superset(self): + """If the new token covers every existing scope, the implicit merge + path can proceed as before.""" + existing = _make_google_oauth2_cred( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) + new_cred = _make_google_oauth2_cred( + cred_id="new-cred-id", + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ], + ) + state = self._build_state( + [ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ] + ) + handler = MagicMock() + handler.exchange_code_for_tokens = AsyncMock(return_value=new_cred) + handler.handle_default_scopes.return_value = state.scopes + + with ( + patch( + "backend.api.features.integrations.router._get_provider_oauth_handler", + return_value=handler, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.verify_state_token = 
AsyncMock(return_value=state) + mock_mgr.store.get_creds_by_provider = AsyncMock(return_value=[existing]) + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.create = AsyncMock() + mock_mgr.update = AsyncMock() + + resp = client.post( + "/google/callback", + json={"code": "auth-code", "state_token": "state-token"}, + ) + + assert resp.status_code == 200 + mock_mgr.update.assert_called_once() + mock_mgr.create.assert_not_called() + + +class TestUpgradeExistingCredentialDoesNotMutateCaller: + """Cursor Low (thread PRRT_kwDOJKSTjM58rern): ``_upgrade_existing_credential`` + used to mutate the caller's ``new_credentials`` object in-place + (overwriting id/title/scopes/metadata/refresh_token/username). Safe + today because all callers immediately replace their reference, but + fragile — a future reader of ``credentials`` after the call would + silently see overwritten values. Pin the contract so the caller's + object stays intact.""" + + @pytest.mark.asyncio + async def test_caller_credentials_object_is_unchanged_after_upgrade(self): + from backend.api.features.integrations.router import ( + _upgrade_existing_credential, + ) + + existing = _make_google_oauth2_cred( + cred_id="existing-cred-id", + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + username="alice@gmail.com", + title="Existing title", + ) + new_credentials = _make_google_oauth2_cred( + cred_id="new-cred-id-from-exchange", + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + ], + username="alice@gmail.com", + title="New title from exchange", + ) + + # Snapshot the caller's object BEFORE the call so we can detect + # any in-place mutation by comparing afterwards. + snapshot = new_credentials.model_copy(deep=True) + + with ( + patch( + "backend.api.features.integrations.router.is_system_credential", + return_value=False, + ), + patch("backend.api.features.integrations.router.creds_manager") as mock_mgr, + ): + mock_mgr.store.get_creds_by_id = AsyncMock(return_value=existing) + mock_mgr.update = AsyncMock() + + returned = await _upgrade_existing_credential( + TEST_USER_ID, existing.id, new_credentials + ) + + # Caller's object must not have been touched — no id/title/scopes + # rewrite, no refresh_token/username/metadata mutation. + assert new_credentials.id == snapshot.id + assert new_credentials.title == snapshot.title + assert new_credentials.scopes == snapshot.scopes + assert new_credentials.metadata == snapshot.metadata + assert new_credentials.username == snapshot.username + assert ( + new_credentials.refresh_token.get_secret_value() + if new_credentials.refresh_token + else None + ) == ( + snapshot.refresh_token.get_secret_value() + if snapshot.refresh_token + else None + ) + + # The returned object carries the merged state, and is persisted. 
+ assert returned.id == existing.id + assert set(returned.scopes) == { + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar.readonly", + } + mock_mgr.update.assert_called_once() diff --git a/autogpt_platform/backend/backend/api/features/integrations/router.py b/autogpt_platform/backend/backend/api/features/integrations/router.py index 1f97b5a987..2ac9f8cabe 100644 --- a/autogpt_platform/backend/backend/api/features/integrations/router.py +++ b/autogpt_platform/backend/backend/api/features/integrations/router.py @@ -14,7 +14,7 @@ from fastapi import ( Security, status, ) -from pydantic import BaseModel, Field, SecretStr, model_validator +from pydantic import BaseModel, Field, model_validator from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from backend.api.features.library.db import set_preset_webhook, update_preset @@ -29,15 +29,14 @@ from backend.data.integrations import ( wait_for_webhook_event, ) from backend.data.model import ( + APIKeyCredentials, Credentials, CredentialsType, HostScopedCredentials, OAuth2Credentials, - UserIntegrations, is_sdk_default, ) from backend.data.onboarding import OnboardingStep, complete_onboarding_step -from backend.data.user import get_user_integrations from backend.executor.utils import add_graph_execution from backend.integrations.ayrshare import AyrshareClient, SocialPlatform from backend.integrations.credentials_store import ( @@ -48,7 +47,14 @@ from backend.integrations.creds_manager import ( IntegrationCredentialsManager, create_mcp_oauth_handler, ) -from backend.integrations.managed_credentials import ensure_managed_credentials +from backend.integrations.managed_credentials import ( + ensure_managed_credential, + ensure_managed_credentials, +) +from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvider +from backend.integrations.managed_providers.ayrshare import ( + settings_available as ayrshare_settings_available, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.integrations.webhooks import get_webhook_manager @@ -87,14 +93,23 @@ async def login( scopes: Annotated[ str, Query(title="Comma-separated list of authorization scopes") ] = "", + credential_id: Annotated[ + str | None, + Query(title="ID of existing credential to upgrade scopes for"), + ] = None, ) -> LoginResponse: handler = _get_provider_oauth_handler(request, provider) requested_scopes = scopes.split(",") if scopes else [] + if credential_id: + requested_scopes = await _prepare_scope_upgrade( + user_id, provider, credential_id, requested_scopes + ) + # Generate and store a secure random state token along with the scopes state_token, code_challenge = await creds_manager.store.store_state_token( - user_id, provider, requested_scopes + user_id, provider, requested_scopes, credential_id=credential_id ) login_url = handler.get_login_url( requested_scopes, state_token, code_challenge=code_challenge @@ -216,7 +231,9 @@ async def callback( ) # TODO: Allow specifying `title` to set on `credentials` - await creds_manager.create(user_id, credentials) + credentials = await _merge_or_create_credential( + user_id, provider, credentials, valid_state.credential_id + ) logger.debug( f"Successfully processed OAuth callback for user {user_id} " @@ -226,13 +243,38 @@ async def callback( return to_meta_response(credentials) +# Bound the first-time sweep so a slow upstream (e.g. 
Ayrshare) can't hang +# the credential-list endpoint. On timeout we still kick off a fire-and- +# forget sweep so provisioning eventually completes; the user just won't +# see the managed cred until the next refresh. +_MANAGED_PROVISION_TIMEOUT_S = 10.0 + + +async def _ensure_managed_credentials_bounded(user_id: str) -> None: + try: + await asyncio.wait_for( + ensure_managed_credentials(user_id, creds_manager.store), + timeout=_MANAGED_PROVISION_TIMEOUT_S, + ) + except asyncio.TimeoutError: + logger.warning( + "Managed credential sweep exceeded %.1fs for user=%s; " + "continuing without it — provisioning will complete in background", + _MANAGED_PROVISION_TIMEOUT_S, + user_id, + ) + asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store)) + + @router.get("/credentials", summary="List Credentials") async def list_credentials( user_id: Annotated[str, Security(get_user_id)], ) -> list[CredentialsMetaResponse]: - # Fire-and-forget: provision missing managed credentials in the background. - # The credential appears on the next page load; listing is never blocked. - asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store)) + # Block on provisioning so managed credentials appear on the first load + # instead of after a refresh, but with a timeout so a slow upstream + # can't hang the endpoint. `_provisioned_users` short-circuits on + # repeat calls. + await _ensure_managed_credentials_bounded(user_id) credentials = await creds_manager.store.get_all_creds(user_id) return [ @@ -247,7 +289,7 @@ async def list_credentials_by_provider( ], user_id: Annotated[str, Security(get_user_id)], ) -> list[CredentialsMetaResponse]: - asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store)) + await _ensure_managed_credentials_bounded(user_id) credentials = await creds_manager.store.get_creds_by_provider(user_id, provider) return [ @@ -281,6 +323,115 @@ async def get_credential( return to_meta_response(credential) +class PickerTokenResponse(BaseModel): + """Short-lived OAuth access token shipped to the browser for rendering a + provider-hosted picker UI (e.g. Google Drive Picker). Deliberately narrow: + only the fields the client needs to initialize the picker widget. Issued + from the user's own stored credential so ownership and scope gating are + enforced by the credential lookup.""" + + access_token: str = Field( + description="OAuth access token suitable for the picker SDK call." + ) + access_token_expires_at: int | None = Field( + default=None, + description="Unix timestamp at which the access token expires, if known.", + ) + + +# Allowlist of (provider, scopes) tuples that may mint picker tokens. Only +# Drive-picker-capable scopes qualify so a caller can't use this endpoint to +# extract a GitHub / other-provider OAuth token for unrelated purposes. If a +# future provider integrates a hosted picker that needs a raw access token, +# add its specific picker-relevant scopes here. 
+_PICKER_TOKEN_ALLOWED_SCOPES: dict[ProviderName, frozenset[str]] = { + ProviderName.GOOGLE: frozenset( + [ + "https://www.googleapis.com/auth/drive.file", + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/drive", + ] + ), +} + + +@router.post( + "/{provider}/credentials/{cred_id}/picker-token", + summary="Issue a short-lived access token for a provider-hosted picker", + operation_id="postV1GetPickerToken", +) +async def get_picker_token( + provider: Annotated[ + ProviderName, Path(title="The provider that owns the credentials") + ], + cred_id: Annotated[ + str, Path(title="The ID of the OAuth2 credentials to mint a token from") + ], + user_id: Annotated[str, Security(get_user_id)], +) -> PickerTokenResponse: + """Return the raw access token for an OAuth2 credential so the frontend + can initialize a provider-hosted picker (e.g. Google Drive Picker). + + `GET /{provider}/credentials/{cred_id}` deliberately strips secrets (see + `CredentialsMetaResponse` + `TestGetCredentialReturnsMetaOnly` in + `router_test.py`). That hardening broke the Drive picker, which needs the + raw access token to call `google.picker.Builder.setOAuthToken(...)`. This + endpoint carves a narrow, explicit hole: the caller must own the + credential, it must be OAuth2, and the endpoint returns only the access + token + its expiry — nothing else about the credential. SDK-default + credentials are excluded for the same reason as `get_credential`. + """ + if is_sdk_default(cred_id): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found" + ) + + credential = await creds_manager.get(user_id, cred_id) + if not credential: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found" + ) + if not provider_matches(credential.provider, provider): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found" + ) + if not isinstance(credential, OAuth2Credentials): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Picker tokens are only available for OAuth2 credentials", + ) + if not credential.access_token: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Credential has no access token; reconnect the account", + ) + + # Gate on provider+scope: only credentials that actually grant access to + # a provider-hosted picker flow may mint a token through this endpoint. + # Prevents using this path to extract bearer tokens for unrelated OAuth + # integrations (e.g. GitHub) that happen to be stored under the same user. + allowed_scopes = _PICKER_TOKEN_ALLOWED_SCOPES.get(provider) + if not allowed_scopes: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=(f"Picker tokens are not available for provider '{provider.value}'"), + ) + cred_scopes = set(credential.scopes or []) + if cred_scopes.isdisjoint(allowed_scopes): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + "Credential does not grant any scope eligible for the picker. " + "Reconnect with the appropriate scope." 
+ ), + ) + + return PickerTokenResponse( + access_token=credential.access_token.get_secret_value(), + access_token_expires_at=credential.access_token_expires_at, + ) + + @router.post("/{provider}/credentials", status_code=201, summary="Create Credentials") async def create_credentials( user_id: Annotated[str, Security(get_user_id)], @@ -574,6 +725,186 @@ async def _execute_webhook_preset_trigger( # Continue processing - webhook should be resilient to individual failures +# -------------------- INCREMENTAL AUTH HELPERS -------------------- # + + +async def _prepare_scope_upgrade( + user_id: str, + provider: ProviderName, + credential_id: str, + requested_scopes: list[str], +) -> list[str]: + """Validate an existing credential for scope upgrade and compute scopes. + + For providers without native incremental auth (e.g. GitHub), returns the + union of existing + requested scopes. For providers that handle merging + server-side (e.g. Google with ``include_granted_scopes``), returns the + requested scopes unchanged. + + Raises HTTPException on validation failure. + """ + # Platform-owned system credentials must never be upgraded — scope + # changes here would leak across every user that shares them. + if is_system_credential(credential_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="System credentials cannot be upgraded", + ) + + existing = await creds_manager.store.get_creds_by_id(user_id, credential_id) + if not existing: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Credential to upgrade not found", + ) + if not isinstance(existing, OAuth2Credentials): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only OAuth2 credentials can be upgraded", + ) + if not provider_matches(existing.provider, provider.value): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Credential provider does not match the requested provider", + ) + if existing.is_managed: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Managed credentials cannot be upgraded", + ) + + # Google handles scope merging via include_granted_scopes; others need + # the union of existing + new scopes in the login URL. + if provider != ProviderName.GOOGLE: + requested_scopes = list(set(requested_scopes) | set(existing.scopes)) + + return requested_scopes + + +async def _merge_or_create_credential( + user_id: str, + provider: ProviderName, + credentials: OAuth2Credentials, + credential_id: str | None, +) -> OAuth2Credentials: + """Either upgrade an existing credential or create a new one. + + When *credential_id* is set (explicit upgrade), merges scopes and updates + the existing credential. Otherwise, checks for an implicit merge (same + provider + username) before falling back to creating a new credential. + """ + if credential_id: + return await _upgrade_existing_credential(user_id, credential_id, credentials) + + # Implicit merge: check for existing credential with same provider+username. + # Skip managed/system credentials and require a non-None username on both + # sides so we never accidentally merge unrelated credentials. 
+ if credentials.username is None: + await creds_manager.create(user_id, credentials) + return credentials + + existing_creds = await creds_manager.store.get_creds_by_provider(user_id, provider) + matching = next( + ( + c + for c in existing_creds + if isinstance(c, OAuth2Credentials) + and not c.is_managed + and not is_system_credential(c.id) + and c.username is not None + and c.username == credentials.username + ), + None, + ) + if matching: + # Only merge into the existing credential when the new token + # already covers every scope we're about to advertise on it. + # Without this guard we'd overwrite ``matching.access_token`` with + # a narrower token while storing a wider ``scopes`` list — the + # record would claim authorizations the token does not grant, and + # blocks using the lost scopes would fail with opaque 401/403s + # until the user hits re-auth. On a narrowing login, keep the + # two credentials separate instead. + if set(credentials.scopes).issuperset(set(matching.scopes)): + return await _upgrade_existing_credential(user_id, matching.id, credentials) + + await creds_manager.create(user_id, credentials) + return credentials + + +async def _upgrade_existing_credential( + user_id: str, + existing_cred_id: str, + new_credentials: OAuth2Credentials, +) -> OAuth2Credentials: + """Merge scopes from *new_credentials* into an existing credential.""" + # Defense-in-depth: re-check system and provider invariants right before + # the write. The login-time check in `_prepare_scope_upgrade` can go stale + # by the time the callback runs, and the implicit-merge path bypasses + # login-time validation entirely, so every write-path must enforce these + # on its own. + if is_system_credential(existing_cred_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="System credentials cannot be upgraded", + ) + existing = await creds_manager.store.get_creds_by_id(user_id, existing_cred_id) + if not existing or not isinstance(existing, OAuth2Credentials): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Credential to upgrade not found", + ) + if existing.is_managed: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Managed credentials cannot be upgraded", + ) + if not provider_matches(existing.provider, new_credentials.provider): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Credential provider does not match the requested provider", + ) + + if ( + existing.username + and new_credentials.username + and existing.username != new_credentials.username + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Username mismatch: authenticated as a different user", + ) + + # Operate on a copy so the caller's ``new_credentials`` object is not + # mutated out from under them. Every caller today immediately discards + # or replaces its reference, but the implicit-merge path in + # ``_merge_or_create_credential`` reads ``credentials.scopes`` before + # calling into us — a future reader after the call would otherwise + # silently see the overwritten values. + merged = new_credentials.model_copy(deep=True) + merged.id = existing.id + merged.title = existing.title + merged.scopes = list(set(existing.scopes) | set(new_credentials.scopes)) + merged.metadata = { + **(existing.metadata or {}), + **(new_credentials.metadata or {}), + } + # Preserve the existing refresh_token and username if the incremental + # response doesn't carry them. 
Providers like Google only return a + # refresh_token on first authorization — dropping it here would orphan + # the credential on the next access-token expiry, forcing the user to + # re-auth from scratch. Username is similarly sticky: if we've already + # resolved it for this credential, keep it rather than silently + # blanking it on an incremental upgrade. + if not merged.refresh_token and existing.refresh_token: + merged.refresh_token = existing.refresh_token + merged.refresh_token_expires_at = existing.refresh_token_expires_at + if not merged.username and existing.username: + merged.username = existing.username + await creds_manager.update(user_id, merged) + return merged + + # --------------------------- UTILITIES ---------------------------- # @@ -784,12 +1115,21 @@ def _get_provider_oauth_handler( async def get_ayrshare_sso_url( user_id: Annotated[str, Security(get_user_id)], ) -> AyrshareSSOResponse: - """ - Generate an SSO URL for Ayrshare social media integration. + """Generate a JWT SSO URL so the user can link their social accounts. - Returns: - dict: Contains the SSO URL for Ayrshare integration + The per-user Ayrshare profile key is provisioned and persisted as a + standard ``is_managed=True`` credential by + :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`. + This endpoint only signs a short-lived JWT pointing at the Ayrshare- + hosted social-linking page; all profile lifecycle logic lives with the + managed provider. """ + if not ayrshare_settings_available(): + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail="Ayrshare integration is not configured", + ) + try: client = AyrshareClient() except MissingConfigError: @@ -798,66 +1138,63 @@ async def get_ayrshare_sso_url( detail="Ayrshare integration is not configured", ) - # Ayrshare profile key is stored in the credentials store - # It is generated when creating a new profile, if there is no profile key, - # we create a new profile and store the profile key in the credentials store - - user_integrations: UserIntegrations = await get_user_integrations(user_id) - profile_key = user_integrations.managed_credentials.ayrshare_profile_key - - if not profile_key: - logger.debug(f"Creating new Ayrshare profile for user {user_id}") - try: - profile = await client.create_profile( - title=f"User {user_id}", messaging_active=True - ) - profile_key = profile.profileKey - await creds_manager.store.set_ayrshare_profile_key(user_id, profile_key) - except Exception as e: - logger.error(f"Error creating Ayrshare profile for user {user_id}: {e}") - raise HTTPException( - status_code=HTTP_502_BAD_GATEWAY, - detail="Failed to create Ayrshare profile", - ) - else: - logger.debug(f"Using existing Ayrshare profile for user {user_id}") - - profile_key_str = ( - profile_key.get_secret_value() - if isinstance(profile_key, SecretStr) - else str(profile_key) + # On-demand provisioning: AyrshareManagedProvider opts out of the + # credentials sweep (profile quota is per-user subscription-bound). This + # endpoint is the only trigger that provisions a profile — one Ayrshare + # profile per user who actually opens the connect flow, not one per + # every authenticated user. 
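The merge rules implemented just above in `_upgrade_existing_credential` (union the scopes, keep the existing refresh token and username when the incremental response omits them, carry over the id and title, overlay metadata) boil down to a small pure function. A minimal sketch; the `TokenRecord` dataclass and its field names are illustrative stand-ins rather than the platform's `OAuth2Credentials` model:

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class TokenRecord:
    """Illustrative stand-in for an OAuth2 credential record."""

    id: str
    title: str
    scopes: frozenset[str]
    access_token: str
    refresh_token: str | None = None
    username: str | None = None
    metadata: dict = field(default_factory=dict)


def merge_upgrade(existing: TokenRecord, incoming: TokenRecord) -> TokenRecord:
    """Merge an incremental-auth result into an existing record.

    Identity (id/title) stays with the existing record, scopes are unioned,
    metadata is overlaid, and refresh_token / username fall back to the
    existing values when the incremental response does not carry them.
    """
    return replace(
        incoming,
        id=existing.id,
        title=existing.title,
        scopes=existing.scopes | incoming.scopes,
        metadata={**existing.metadata, **incoming.metadata},
        refresh_token=incoming.refresh_token or existing.refresh_token,
        username=incoming.username or existing.username,
    )


if __name__ == "__main__":
    old = TokenRecord(
        id="cred-1",
        title="Google",
        scopes=frozenset({"gmail.readonly"}),
        access_token="old-token",
        refresh_token="rt-1",
        username="alice",
    )
    # Incremental grant: new access token, extra scope, no refresh_token/username.
    new = TokenRecord(
        id="tmp", title="tmp", scopes=frozenset({"drive.file"}), access_token="new-token"
    )
    merged = merge_upgrade(old, new)
    assert merged.scopes == {"gmail.readonly", "drive.file"}
    assert merged.access_token == "new-token"
    assert merged.refresh_token == "rt-1" and merged.username == "alice"
```

On a scope-narrowing login the implicit-merge path above deliberately keeps the two credentials separate; this sketch only covers the upgrade direction.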
+ provisioned = await ensure_managed_credential( + user_id, creds_manager.store, AyrshareManagedProvider() ) + if not provisioned: + raise HTTPException( + status_code=HTTP_502_BAD_GATEWAY, + detail="Failed to provision Ayrshare profile", + ) + + ayrshare_creds = [ + c + for c in await creds_manager.store.get_creds_by_provider(user_id, "ayrshare") + if c.is_managed and isinstance(c, APIKeyCredentials) + ] + if not ayrshare_creds: + logger.error( + "Ayrshare credential provisioning did not produce a credential " + "for user %s", + user_id, + ) + raise HTTPException( + status_code=HTTP_502_BAD_GATEWAY, + detail="Failed to provision Ayrshare profile", + ) + profile_key_str = ayrshare_creds[0].api_key.get_secret_value() private_key = settings.secrets.ayrshare_jwt_key - # Ayrshare JWT expiry is 2880 minutes (48 hours) + # Ayrshare JWT max lifetime is 2880 minutes (48 h). max_expiry_minutes = 2880 try: - logger.debug(f"Generating Ayrshare JWT for user {user_id}") jwt_response = await client.generate_jwt( private_key=private_key, profile_key=profile_key_str, + # `allowed_social` is the set of networks the Ayrshare-hosted + # social-linking page will *offer* the user to connect. Blocks + # exist for more platforms than are listed here; the list is + # deliberately narrower so the rollout can verify each network + # end-to-end before widening the user-visible surface. Keep + # in sync with tested platforms — extend as each is verified + # against the block + Ayrshare's network-specific quirks. allowed_social=[ - # NOTE: We are enabling platforms one at a time - # to speed up the development process - # SocialPlatform.FACEBOOK, SocialPlatform.TWITTER, SocialPlatform.LINKEDIN, SocialPlatform.INSTAGRAM, SocialPlatform.YOUTUBE, - # SocialPlatform.REDDIT, - # SocialPlatform.TELEGRAM, - # SocialPlatform.GOOGLE_MY_BUSINESS, - # SocialPlatform.PINTEREST, SocialPlatform.TIKTOK, - # SocialPlatform.BLUESKY, - # SocialPlatform.SNAPCHAT, - # SocialPlatform.THREADS, ], expires_in=max_expiry_minutes, verify=True, ) - except Exception as e: - logger.error(f"Error generating Ayrshare JWT for user {user_id}: {e}") + except Exception as exc: + logger.error("Error generating Ayrshare JWT for user %s: %s", user_id, exc) raise HTTPException( status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT" ) diff --git a/autogpt_platform/backend/backend/api/features/integrations/router_test.py b/autogpt_platform/backend/backend/api/features/integrations/router_test.py index 47f8b7a770..7c3a146aa9 100644 --- a/autogpt_platform/backend/backend/api/features/integrations/router_test.py +++ b/autogpt_platform/backend/backend/api/features/integrations/router_test.py @@ -393,7 +393,7 @@ class TestEnsureManagedCredentials: _PROVIDERS.update(saved) _provisioned_users.pop("user-1", None) - provider.provision.assert_awaited_once_with("user-1") + provider.provision.assert_awaited_once_with("user-1", store) store.add_managed_credential.assert_awaited_once_with("user-1", cred) @pytest.mark.asyncio @@ -568,3 +568,181 @@ class TestCleanupManagedCredentials: _PROVIDERS.update(saved) # No exception raised — cleanup failure is swallowed. + + +class TestGetPickerToken: + """POST /{provider}/credentials/{cred_id}/picker-token must: + 1. Return the access token for OAuth2 creds the caller owns. + 2. 404 for non-owned, non-existent, or wrong-provider creds. + 3. 400 for non-OAuth2 creds (API key, host-scoped, user/password). + 4. 404 for SDK default creds (same hardening as get_credential). + 5. 
Preserve the `TestGetCredentialReturnsMetaOnly` contract — the + existing meta-only endpoint must still strip secrets even after + this picker-token endpoint exists.""" + + def test_oauth2_owner_gets_access_token(self): + # Use a Google cred with a drive.file scope — only picker-eligible + # (provider, scope) pairs can mint a token. GitHub-style creds are + # explicitly rejected; see `test_non_picker_provider_rejected_as_400`. + cred = _make_oauth2_cred( + cred_id="cred-gdrive", + provider="google", + ) + cred.scopes = ["https://www.googleapis.com/auth/drive.file"] + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/google/credentials/cred-gdrive/picker-token") + + assert resp.status_code == 200 + data = resp.json() + # The whole point of this endpoint: the access token IS returned here. + assert data["access_token"] == "ghp_secret_token" + # Only the two declared fields come back — nothing else leaks. + assert set(data.keys()) <= {"access_token", "access_token_expires_at"} + + def test_non_picker_provider_rejected_as_400(self): + """Provider allowlist: even with a valid OAuth2 credential, a + non-picker provider (GitHub, etc.) cannot mint a picker token. + Stops this endpoint from being used as a generic bearer-token + extraction path for any stored OAuth cred under the same user.""" + cred = _make_oauth2_cred(provider="github") + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/github/credentials/cred-456/picker-token") + + assert resp.status_code == 400 + assert "not available for provider" in resp.json()["detail"] + assert "ghp_secret_token" not in str(resp.json()) + + def test_google_oauth_without_drive_scope_rejected(self): + """Scope allowlist: a Google OAuth2 cred that only carries non-picker + scopes (e.g. gmail.readonly, calendar) cannot mint a picker token. + Forces the frontend to reconnect with a Drive scope before the + picker is available.""" + cred = _make_oauth2_cred(provider="google") + cred.scopes = [ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/calendar", + ] + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/google/credentials/cred-456/picker-token") + + assert resp.status_code == 400 + assert "picker" in resp.json()["detail"].lower() + + def test_api_key_credential_rejected_as_400(self): + cred = _make_api_key_cred() + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/openai/credentials/cred-123/picker-token") + + assert resp.status_code == 400 + # API keys must not silently fall through to a 200 response of some + # other shape — the client should see a clear shape rejection. 
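The handler's gating logic is not visible in this section (only its final `return PickerTokenResponse(...)` appears earlier), but the contract these tests pin down can be sketched as a small decision function. The names, dict shape, and check ordering below are inferred from the assertions, not copied from the implementation:

```python
# Picker-eligible (provider, scope) pairs; only the Google Drive file scope
# is exercised by the tests above.
PICKER_SCOPES = {
    "google": {"https://www.googleapis.com/auth/drive.file"},
}


def picker_token_decision(requested_provider: str, credential: dict | None) -> tuple[int, str]:
    """Return (status_code, reason) for a picker-token request.

    `credential` is a plain-dict stand-in with keys `provider`, `type`,
    `scopes`, and `access_token`. Existence/ownership problems map to a
    generic 404; shape and scope problems map to 400.
    """
    if credential is None or credential["provider"] != requested_provider:
        return 404, "Credentials not found"
    if credential["type"] != "oauth2":
        return 400, "Only OAuth2 credentials can mint a picker token"
    allowed = PICKER_SCOPES.get(requested_provider)
    if allowed is None:
        return 400, f"Picker tokens are not available for provider {requested_provider}"
    if not allowed & set(credential["scopes"]):
        return 400, "Credential lacks a picker-eligible scope"
    if not credential["access_token"]:
        return 400, "Access token missing; reconnect the integration"
    return 200, "ok"


assert picker_token_decision("google", None)[0] == 404
assert picker_token_decision(
    "github", {"provider": "github", "type": "oauth2", "scopes": [], "access_token": "t"}
)[0] == 400
```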
+ body = str(resp.json()) + assert "sk-secret-key-value" not in body + + def test_user_password_credential_rejected_as_400(self): + cred = _make_user_password_cred() + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/openai/credentials/cred-789/picker-token") + + assert resp.status_code == 400 + body = str(resp.json()) + assert "s3cret-pass" not in body + assert "admin" not in body + + def test_host_scoped_credential_rejected_as_400(self): + cred = _make_host_scoped_cred() + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/openai/credentials/cred-host/picker-token") + + assert resp.status_code == 400 + assert "top-secret" not in str(resp.json()) + + def test_missing_credential_returns_404(self): + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=None) + resp = client.post("/github/credentials/nonexistent/picker-token") + + assert resp.status_code == 404 + assert resp.json()["detail"] == "Credentials not found" + + def test_wrong_provider_returns_404(self): + """Symmetric with get_credential: provider mismatch is a generic + 404, not a 400, so we don't leak existence of a credential the + caller doesn't own on that provider.""" + cred = _make_oauth2_cred(provider="github") + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/google/credentials/cred-456/picker-token") + + assert resp.status_code == 404 + assert resp.json()["detail"] == "Credentials not found" + + def test_sdk_default_returns_404(self): + """SDK defaults are invisible to the user-facing API — picker-token + must not mint a token for them either.""" + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock() + resp = client.post("/openai/credentials/openai-default/picker-token") + + assert resp.status_code == 404 + mock_mgr.get.assert_not_called() + + def test_oauth2_without_access_token_returns_400(self): + """A stored OAuth2 cred whose access_token is missing can't satisfy + a picker init. Surface a clear reconnect instruction rather than + returning an empty string.""" + cred = _make_oauth2_cred() + # Simulate a cred that lost its access token + object.__setattr__(cred, "access_token", None) + + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.post("/github/credentials/cred-456/picker-token") + + assert resp.status_code == 400 + assert "reconnect" in resp.json()["detail"].lower() + + def test_meta_only_endpoint_still_strips_access_token(self): + """Regression guard for the coexistence contract: the new + picker-token endpoint must NOT accidentally leak the token through + the meta-only GET endpoint. 
TestGetCredentialReturnsMetaOnly + covers this more broadly; this is a fast sanity check co-located + with the new endpoint's tests.""" + cred = _make_oauth2_cred() + with patch( + "backend.api.features.integrations.router.creds_manager" + ) as mock_mgr: + mock_mgr.get = AsyncMock(return_value=cred) + resp = client.get("/github/credentials/cred-456") + + assert resp.status_code == 200 + body = resp.json() + assert "access_token" not in body + assert "refresh_token" not in body + assert "ghp_secret_token" not in str(body) diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 1e01ea638f..0e21edf061 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -743,6 +743,7 @@ async def update_library_agent_version_and_settings( graph=agent_graph, hitl_safe_mode=library.settings.human_in_the_loop_safe_mode, sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode, + builder_chat_session_id=library.settings.builder_chat_session_id, ) if updated_settings != library.settings: library = await update_library_agent( @@ -1803,7 +1804,7 @@ async def create_preset_from_graph_execution( raise NotFoundError( f"Graph #{graph_execution.graph_id} not found or accessible" ) - elif len(graph.aggregate_credentials_inputs()) > 0: + elif len(graph.regular_credentials_inputs) > 0: raise ValueError( f"Graph execution #{graph_exec_id} can't be turned into a preset " "because it was run before this feature existed " diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py b/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py new file mode 100644 index 0000000000..7764686098 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py @@ -0,0 +1 @@ +"""Platform bot linking — user-facing REST routes.""" diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/routes.py b/autogpt_platform/backend/backend/api/features/platform_linking/routes.py new file mode 100644 index 0000000000..7b0f845c01 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/routes.py @@ -0,0 +1,158 @@ +"""User-facing platform_linking REST routes (JWT auth).""" + +import logging +from typing import Annotated + +from autogpt_libs import auth +from fastapi import APIRouter, HTTPException, Path, Security + +from backend.data.db_accessors import platform_linking_db +from backend.platform_linking.models import ( + ConfirmLinkResponse, + ConfirmUserLinkResponse, + DeleteLinkResponse, + LinkTokenInfoResponse, + PlatformLinkInfo, + PlatformUserLinkInfo, +) +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + +TokenPath = Annotated[ + str, + Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"), +] + + +def _translate(exc: Exception) -> HTTPException: + if isinstance(exc, NotFoundError): + return HTTPException(status_code=404, detail=str(exc)) + if isinstance(exc, NotAuthorizedError): + return HTTPException(status_code=403, detail=str(exc)) + if isinstance(exc, LinkAlreadyExistsError): + return HTTPException(status_code=409, detail=str(exc)) + if isinstance(exc, LinkTokenExpiredError): + return HTTPException(status_code=410, detail=str(exc)) + if isinstance(exc, 
LinkFlowMismatchError): + return HTTPException(status_code=400, detail=str(exc)) + return HTTPException(status_code=500, detail="Internal error.") + + +@router.get( + "/tokens/{token}/info", + response_model=LinkTokenInfoResponse, + dependencies=[Security(auth.requires_user)], + summary="Get display info for a link token", +) +async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse: + try: + return await platform_linking_db().get_link_token_info(token) + except (NotFoundError, LinkTokenExpiredError) as exc: + raise _translate(exc) from exc + + +@router.post( + "/tokens/{token}/confirm", + response_model=ConfirmLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Confirm a SERVER link token (user must be authenticated)", +) +async def confirm_link_token( + token: TokenPath, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> ConfirmLinkResponse: + try: + return await platform_linking_db().confirm_server_link(token, user_id) + except ( + NotFoundError, + LinkFlowMismatchError, + LinkTokenExpiredError, + LinkAlreadyExistsError, + ) as exc: + raise _translate(exc) from exc + + +@router.post( + "/user-tokens/{token}/confirm", + response_model=ConfirmUserLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Confirm a USER link token (user must be authenticated)", +) +async def confirm_user_link_token( + token: TokenPath, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> ConfirmUserLinkResponse: + try: + return await platform_linking_db().confirm_user_link(token, user_id) + except ( + NotFoundError, + LinkFlowMismatchError, + LinkTokenExpiredError, + LinkAlreadyExistsError, + ) as exc: + raise _translate(exc) from exc + + +@router.get( + "/links", + response_model=list[PlatformLinkInfo], + dependencies=[Security(auth.requires_user)], + summary="List all platform servers linked to the authenticated user", +) +async def list_my_links( + user_id: Annotated[str, Security(auth.get_user_id)], +) -> list[PlatformLinkInfo]: + return await platform_linking_db().list_server_links(user_id) + + +@router.get( + "/user-links", + response_model=list[PlatformUserLinkInfo], + dependencies=[Security(auth.requires_user)], + summary="List all DM links for the authenticated user", +) +async def list_my_user_links( + user_id: Annotated[str, Security(auth.get_user_id)], +) -> list[PlatformUserLinkInfo]: + return await platform_linking_db().list_user_links(user_id) + + +@router.delete( + "/links/{link_id}", + response_model=DeleteLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Unlink a platform server", +) +async def delete_link( + link_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> DeleteLinkResponse: + try: + return await platform_linking_db().delete_server_link(link_id, user_id) + except (NotFoundError, NotAuthorizedError) as exc: + raise _translate(exc) from exc + + +@router.delete( + "/user-links/{link_id}", + response_model=DeleteLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Unlink a DM / user link", +) +async def delete_user_link_route( + link_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> DeleteLinkResponse: + try: + return await platform_linking_db().delete_user_link(link_id, user_id) + except (NotFoundError, NotAuthorizedError) as exc: + raise _translate(exc) from exc diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py b/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py 
new file mode 100644 index 0000000000..944ef8eb6a --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py @@ -0,0 +1,264 @@ +"""Route tests: domain exceptions → HTTPException status codes.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) + + +def _db_mock(**method_configs): + """Return a mock of the accessor's return value with the given AsyncMocks.""" + db = MagicMock() + for name, mock in method_configs.items(): + setattr(db, name, mock) + return db + + +class TestTokenInfoRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import ( + get_link_token_info_route, + ) + + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await get_link_token_info_route(token="abc") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_expired_maps_to_410(self): + from backend.api.features.platform_linking.routes import ( + get_link_token_info_route, + ) + + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await get_link_token_info_route(token="abc") + assert exc.value.status_code == 410 + + +class TestConfirmLinkRouteTranslation: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exc,expected_status", + [ + (NotFoundError("missing"), 404), + (LinkFlowMismatchError("wrong flow"), 400), + (LinkTokenExpiredError("expired"), 410), + (LinkAlreadyExistsError("already"), 409), + ], + ) + async def test_translation(self, exc: Exception, expected_status: int): + from backend.api.features.platform_linking.routes import confirm_link_token + + db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc)) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as ctx: + await confirm_link_token(token="abc", user_id="u1") + assert ctx.value.status_code == expected_status + + +class TestConfirmUserLinkRouteTranslation: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exc,expected_status", + [ + (NotFoundError("missing"), 404), + (LinkFlowMismatchError("wrong flow"), 400), + (LinkTokenExpiredError("expired"), 410), + (LinkAlreadyExistsError("already"), 409), + ], + ) + async def test_translation(self, exc: Exception, expected_status: int): + from backend.api.features.platform_linking.routes import confirm_user_link_token + + db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc)) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as ctx: + await confirm_user_link_token(token="abc", user_id="u1") + assert ctx.value.status_code == expected_status + + +class TestDeleteLinkRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import delete_link + + db = _db_mock( + 
delete_server_link=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_link(link_id="x", user_id="u1") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_not_owned_maps_to_403(self): + from backend.api.features.platform_linking.routes import delete_link + + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_link(link_id="x", user_id="u1") + assert exc.value.status_code == 403 + + +class TestDeleteUserLinkRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import delete_user_link_route + + db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing"))) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_user_link_route(link_id="x", user_id="u1") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_not_owned_maps_to_403(self): + from backend.api.features.platform_linking.routes import delete_user_link_route + + db = _db_mock( + delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_user_link_route(link_id="x", user_id="u1") + assert exc.value.status_code == 403 + + +# ── Adversarial: malformed token path params ────────────────────────── + + +class TestAdversarialTokenPath: + # TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64. 
+ + @pytest.fixture + def client(self): + import fastapi + from autogpt_libs.auth import get_user_id, requires_user + from fastapi.testclient import TestClient + + import backend.api.features.platform_linking.routes as routes_mod + + app = fastapi.FastAPI() + app.dependency_overrides[requires_user] = lambda: None + app.dependency_overrides[get_user_id] = lambda: "caller-user" + app.include_router(routes_mod.router, prefix="/api/platform-linking") + return TestClient(app) + + def test_rejects_token_with_special_chars(self, client): + response = client.get("/api/platform-linking/tokens/bad%24token/info") + assert response.status_code == 422 + + def test_rejects_token_with_path_traversal(self, client): + for probe in ("..%2F..", "foo..bar", "foo%2Fbar"): + response = client.get(f"/api/platform-linking/tokens/{probe}/info") + assert response.status_code in ( + 404, + 422, + ), f"path-traversal probe {probe!r} returned {response.status_code}" + + def test_rejects_token_too_long(self, client): + long_token = "a" * 65 + response = client.get(f"/api/platform-linking/tokens/{long_token}/info") + assert response.status_code == 422 + + def test_accepts_token_at_max_length(self, client): + token = "a" * 64 + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + response = client.get(f"/api/platform-linking/tokens/{token}/info") + assert response.status_code == 404 + + def test_accepts_urlsafe_b64_token_shape(self, client): + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info") + assert response.status_code == 404 + + def test_confirm_rejects_malformed_token(self, client): + response = client.post("/api/platform-linking/tokens/bad%24token/confirm") + assert response.status_code == 422 + + +class TestAdversarialDeleteLinkId: + """DELETE link_id has no regex — ensure weird values are handled via + NotFoundError (no crash, no cross-user leak).""" + + @pytest.fixture + def client(self): + import fastapi + from autogpt_libs.auth import get_user_id, requires_user + from fastapi.testclient import TestClient + + import backend.api.features.platform_linking.routes as routes_mod + + app = fastapi.FastAPI() + app.dependency_overrides[requires_user] = lambda: None + app.dependency_overrides[get_user_id] = lambda: "caller-user" + app.include_router(routes_mod.router, prefix="/api/platform-linking") + return TestClient(app) + + def test_weird_link_id_returns_404(self, client): + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""): + response = client.delete(f"/api/platform-linking/links/{link_id}") + assert response.status_code in (404, 405) diff --git a/autogpt_platform/backend/backend/api/features/store/db_test.py b/autogpt_platform/backend/backend/api/features/store/db_test.py index f3acd867d3..6d8cde4299 100644 --- a/autogpt_platform/backend/backend/api/features/store/db_test.py +++ b/autogpt_platform/backend/backend/api/features/store/db_test.py @@ -189,7 +189,7 @@ async def test_create_store_submission(mocker): notifyOnAgentApproved=True, 
notifyOnAgentRejected=True, timezone="Europe/Delft", - subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue] + subscriptionTier=prisma.enums.SubscriptionTier.BASIC, # type: ignore[reportCallIssue,reportAttributeAccessIssue] ) mock_agent = prisma.models.AgentGraph( id="agent-id", diff --git a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py index c20e0d0ceb..e353a2e777 100644 --- a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py @@ -47,6 +47,62 @@ def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None: ) +@pytest.fixture(autouse=True) +def _stub_pending_subscription_change(mocker: pytest_mock.MockFixture) -> None: + """Default pending-change lookup to None so tests don't hit Stripe/DB. + + Individual tests can override via their own mocker.patch call. + """ + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + +_DEFAULT_TIER_PRICES: dict[SubscriptionTier, str | None] = { + SubscriptionTier.BASIC: None, # Legacy: stripe-price-id-basic unset by default. + SubscriptionTier.PRO: "price_pro", + SubscriptionTier.MAX: "price_max", + SubscriptionTier.BUSINESS: None, # Reserved: Business card hidden by default. +} + + +@pytest.fixture(autouse=True) +def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None: + """Stub Stripe price + proration + tier-multiplier lookups used by + get_subscription_status. + + The POST /credits/subscription handler now returns the full subscription + status payload from every branch (same-tier, BASIC downgrade, paid→paid + modify, checkout creation), so every POST test implicitly hits these + helpers. Individual tests can override via their own mocker.patch call. + """ + + async def default_price_id(tier: SubscriptionTier) -> str | None: + return _DEFAULT_TIER_PRICES.get(tier) + + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=default_price_id, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + # Default tier-multiplier resolver to the backend defaults so the endpoint + # never reaches LaunchDarkly during tests. Individual tests override for + # LD-override scenarios. 
+ from backend.copilot.rate_limit import _DEFAULT_TIER_MULTIPLIERS + + mocker.patch( + "backend.api.features.v1.get_tier_multipliers", + new_callable=AsyncMock, + return_value=dict(_DEFAULT_TIER_MULTIPLIERS), + ) + + @pytest.mark.parametrize( "url,expected", [ @@ -88,15 +144,28 @@ def test_get_subscription_status_pro( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """GET /credits/subscription returns PRO tier with Stripe price for a PRO user.""" + """GET /credits/subscription returns PRO tier with Stripe prices for all priced tiers.""" mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO + prices = { + SubscriptionTier.BASIC: "price_basic", + SubscriptionTier.PRO: "price_pro", + SubscriptionTier.MAX: "price_max", + SubscriptionTier.BUSINESS: "price_business", + } + amounts = { + "price_basic": 0, + "price_pro": 1999, + "price_max": 4999, + "price_business": 14999, + } + async def mock_price_id(tier: SubscriptionTier) -> str | None: - return "price_pro" if tier == SubscriptionTier.PRO else None + return prices.get(tier) async def mock_stripe_price_amount(price_id: str) -> int: - return 1999 if price_id == "price_pro" else 0 + return amounts.get(price_id, 0) mocker.patch( "backend.api.features.v1.get_user_by_id", @@ -124,16 +193,63 @@ def test_get_subscription_status_pro( assert data["tier"] == "PRO" assert data["monthly_cost"] == 1999 assert data["tier_costs"]["PRO"] == 1999 - assert data["tier_costs"]["BUSINESS"] == 0 - assert data["tier_costs"]["FREE"] == 0 + assert data["tier_costs"]["MAX"] == 4999 + assert data["tier_costs"]["BUSINESS"] == 14999 + assert data["tier_costs"]["BASIC"] == 0 + assert "ENTERPRISE" not in data["tier_costs"] assert data["proration_credit_cents"] == 500 + # tier_multipliers mirrors the same set of tiers that land in tier_costs, + # so the frontend never renders a multiplier badge for a hidden row. + assert set(data["tier_multipliers"].keys()) == set(data["tier_costs"].keys()) + assert data["tier_multipliers"]["BASIC"] == 1.0 + assert data["tier_multipliers"]["PRO"] == 5.0 + assert data["tier_multipliers"]["MAX"] == 20.0 + assert data["tier_multipliers"]["BUSINESS"] == 60.0 -def test_get_subscription_status_defaults_to_free( +def test_get_subscription_status_tier_multipliers_ld_override( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """GET /credits/subscription when subscription_tier is None defaults to FREE.""" + """A LaunchDarkly-overridden tier multiplier flows through the response.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BASIC + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + + # LD says PRO is 7.5× (instead of the 5× default); other tiers unchanged. + mocker.patch( + "backend.api.features.v1.get_tier_multipliers", + new_callable=AsyncMock, + return_value={ + SubscriptionTier.BASIC: 1.0, + SubscriptionTier.PRO: 7.5, + SubscriptionTier.MAX: 20.0, + SubscriptionTier.BUSINESS: 60.0, + SubscriptionTier.ENTERPRISE: 60.0, + }, + ) + + response = client.get("/credits/subscription") + assert response.status_code == 200 + data = response.json() + # Only tiers that made it into tier_costs get a multiplier (default stub + # exposes PRO + MAX via _DEFAULT_TIER_PRICES). + assert data["tier_multipliers"]["PRO"] == 7.5 + assert data["tier_multipliers"]["MAX"] == 20.0 + # BUSINESS has no price configured → hidden from both maps. 
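The contract these assertions pin down (only tiers with a configured Stripe price land in `tier_costs`, and `tier_multipliers` exposes exactly the same key set) can be stated as a small pure helper. A rough sketch, with plain dicts standing in for the real LaunchDarkly and Stripe lookups:

```python
def visible_tier_maps(
    price_ids: dict[str, str | None],
    prices_cents: dict[str, int],
    multipliers: dict[str, float],
) -> tuple[dict[str, int], dict[str, float]]:
    """Build the (tier_costs, tier_multipliers) pair the status endpoint returns.

    A tier is visible only when a price ID is configured for it; hidden tiers
    are dropped from both maps so the two always share one key set.
    """
    tier_costs = {
        tier: prices_cents.get(pid, 0)
        for tier, pid in price_ids.items()
        if pid is not None
    }
    tier_multipliers = {tier: multipliers.get(tier, 1.0) for tier in tier_costs}
    return tier_costs, tier_multipliers


costs, mults = visible_tier_maps(
    price_ids={"BASIC": None, "PRO": "price_pro", "MAX": "price_max", "BUSINESS": None},
    prices_cents={"price_pro": 1999, "price_max": 4999},
    multipliers={"BASIC": 1.0, "PRO": 7.5, "MAX": 20.0, "BUSINESS": 60.0},
)
assert costs == {"PRO": 1999, "MAX": 4999}
assert set(mults) == set(costs) and "BUSINESS" not in mults
```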
+ assert "BUSINESS" not in data["tier_multipliers"] + + +def test_get_subscription_status_defaults_to_basic( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """When all LD price IDs are unset, tier_costs is empty and the caller sees cost=0.""" mock_user = Mock() mock_user.subscription_tier = None @@ -157,14 +273,9 @@ def test_get_subscription_status_defaults_to_free( assert response.status_code == 200 data = response.json() - assert data["tier"] == SubscriptionTier.FREE.value + assert data["tier"] == SubscriptionTier.BASIC.value assert data["monthly_cost"] == 0 - assert data["tier_costs"] == { - "FREE": 0, - "PRO": 0, - "BUSINESS": 0, - "ENTERPRISE": 0, - } + assert data["tier_costs"] == {} assert data["proration_credit_cents"] == 0 @@ -215,11 +326,11 @@ def test_get_subscription_status_stripe_error_falls_back_to_zero( assert data["tier_costs"]["PRO"] == 0 -def test_update_subscription_tier_free_no_payment( +def test_update_subscription_tier_basic_no_payment( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """POST /credits/subscription to FREE tier when payment disabled skips Stripe.""" + """POST /credits/subscription to BASIC tier when payment disabled skips Stripe.""" mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO @@ -240,7 +351,7 @@ def test_update_subscription_tier_free_no_payment( new_callable=AsyncMock, ) - response = client.post("/credits/subscription", json={"tier": "FREE"}) + response = client.post("/credits/subscription", json={"tier": "BASIC"}) assert response.status_code == 200 assert response.json()["url"] == "" @@ -252,7 +363,7 @@ def test_update_subscription_tier_paid_beta_user( ) -> None: """POST /credits/subscription for paid tier when payment disabled returns 422.""" mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.FREE + mock_user.subscription_tier = SubscriptionTier.BASIC async def mock_feature_disabled(*args, **kwargs): return False @@ -279,7 +390,7 @@ def test_update_subscription_tier_paid_requires_urls( ) -> None: """POST /credits/subscription for paid tier without success/cancel URLs returns 422.""" mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.FREE + mock_user.subscription_tier = SubscriptionTier.BASIC async def mock_feature_enabled(*args, **kwargs): return True @@ -305,7 +416,7 @@ def test_update_subscription_tier_creates_checkout( ) -> None: """POST /credits/subscription creates Stripe Checkout Session for paid upgrade.""" mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.FREE + mock_user.subscription_tier = SubscriptionTier.BASIC async def mock_feature_enabled(*args, **kwargs): return True @@ -344,7 +455,7 @@ def test_update_subscription_tier_rejects_open_redirect( ) -> None: """POST /credits/subscription rejects success/cancel URLs outside the frontend origin.""" mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.FREE + mock_user.subscription_tier = SubscriptionTier.BASIC async def mock_feature_enabled(*args, **kwargs): return True @@ -407,30 +518,77 @@ def test_update_subscription_tier_enterprise_blocked( set_tier_mock.assert_not_awaited() -def test_update_subscription_tier_same_tier_is_noop( +def test_update_subscription_tier_same_tier_releases_pending_change( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """POST /credits/subscription for the user's current paid tier returns 200 with empty URL. 
+ """POST /credits/subscription for the user's current tier releases any pending change. - Without this guard a duplicate POST (double-click, browser retry, stale page) would - create a second Stripe Checkout Session for the same price, potentially billing the - user twice until the webhook reconciliation fires. + "Stay on my current tier" — the collapsed replacement for the old + /credits/subscription/cancel-pending route. Always calls + release_pending_subscription_schedule (idempotent when nothing is pending) + and returns the refreshed status with url="". Never creates a Checkout + Session — that would double-charge a user who double-clicks their own tier. """ mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.PRO - - async def mock_feature_enabled(*args, **kwargs): - return True + mock_user.subscription_tier = SubscriptionTier.BUSINESS mocker.patch( "backend.api.features.v1.get_user_by_id", new_callable=AsyncMock, return_value=mock_user, ) - mocker.patch( + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + feature_mock = mocker.patch( "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_enabled, + new_callable=AsyncMock, + return_value=True, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "BUSINESS" + assert data["url"] == "" + release_mock.assert_awaited_once_with(TEST_USER_ID) + checkout_mock.assert_not_awaited() + # Same-tier branch short-circuits before the payment-flag check. + feature_mock.assert_not_awaited() + + +def test_update_subscription_tier_same_tier_no_pending_change_returns_status( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Same-tier request when nothing is pending still returns status with url="".""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=False, ) checkout_mock = mocker.patch( "backend.api.features.v1.create_subscription_checkout", @@ -447,18 +605,58 @@ def test_update_subscription_tier_same_tier_is_noop( ) assert response.status_code == 200 - assert response.json()["url"] == "" + data = response.json() + assert data["tier"] == "PRO" + assert data["url"] == "" + assert data["pending_tier"] is None + release_mock.assert_awaited_once_with(TEST_USER_ID) checkout_mock.assert_not_awaited() -def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db( +def test_update_subscription_tier_same_tier_stripe_error_returns_502( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """Downgrading to FREE schedules Stripe cancellation at period end. + """Same-tier request surfaces a 502 when Stripe release fails. + + Carries forward the error contract from the removed + /credits/subscription/cancel-pending route so clients keep seeing 502 for + transient Stripe failures. 
+ """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + side_effect=stripe.StripeError("network"), + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 502 + assert "contact support" in response.json()["detail"].lower() + + +def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_not_update_db( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Downgrading to BASIC schedules Stripe cancellation at period end. The DB tier must NOT be updated immediately — the customer.subscription.deleted - webhook fires at period end and downgrades to FREE then. + webhook fires at period end and downgrades to BASIC then. """ mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO @@ -484,18 +682,18 @@ def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_no side_effect=mock_feature_enabled, ) - response = client.post("/credits/subscription", json={"tier": "FREE"}) + response = client.post("/credits/subscription", json={"tier": "BASIC"}) assert response.status_code == 200 mock_cancel.assert_awaited_once() mock_set_tier.assert_not_awaited() -def test_update_subscription_tier_free_cancel_failure_returns_502( +def test_update_subscription_tier_basic_cancel_failure_returns_502( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """Downgrading to FREE returns 502 with a generic error (no Stripe detail leakage).""" + """Downgrading to BASIC returns 502 with a generic error (no Stripe detail leakage).""" mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO @@ -518,7 +716,7 @@ def test_update_subscription_tier_free_cancel_failure_returns_502( side_effect=mock_feature_enabled, ) - response = client.post("/credits/subscription", json={"tier": "FREE"}) + response = client.post("/credits/subscription", json={"tier": "BASIC"}) assert response.status_code == 502 detail = response.json()["detail"] @@ -635,6 +833,16 @@ def test_update_subscription_tier_paid_to_paid_modifies_subscription( mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO + async def price_id_with_business(tier: SubscriptionTier) -> str | None: + return { + **_DEFAULT_TIER_PRICES, + SubscriptionTier.BUSINESS: "price_business", + }.get(tier) + + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=price_id_with_business, + ) mocker.patch( "backend.api.features.v1.get_user_by_id", new_callable=AsyncMock, @@ -670,6 +878,49 @@ def test_update_subscription_tier_paid_to_paid_modifies_subscription( checkout_mock.assert_not_awaited() +def test_update_subscription_tier_max_checkout( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription from PRO→MAX modifies the existing subscription.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + 
modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "MAX", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.MAX) + checkout_mock.assert_not_awaited() + + def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, @@ -683,6 +934,16 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO + async def price_id_with_business(tier: SubscriptionTier) -> str | None: + return { + **_DEFAULT_TIER_PRICES, + SubscriptionTier.BUSINESS: "price_business", + }.get(tier) + + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=price_id_with_business, + ) mocker.patch( "backend.api.features.v1.get_user_by_id", new_callable=AsyncMock, @@ -725,6 +986,128 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly checkout_mock.assert_not_awaited() +def test_update_subscription_tier_priced_basic_no_sub_falls_through_to_checkout( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Once stripe-price-id-basic is configured, a BASIC user without an active sub + must hit Stripe Checkout rather than being silently set_subscription_tier'd.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BASIC + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return { + SubscriptionTier.BASIC: "price_basic", + SubscriptionTier.PRO: "price_pro", + SubscriptionTier.MAX: "price_max", + SubscriptionTier.BUSINESS: "price_business", + }.get(tier) + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=False, + ) + set_tier_mock = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + return_value="https://checkout.stripe.com/pay/cs_test_priced_basic", + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert ( + response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_priced_basic" + ) + # Priced-BASIC user without an active sub: must NOT silently flip DB tier — + # they need to set up payment via Checkout. + set_tier_mock.assert_not_awaited() + checkout_mock.assert_awaited_once() + # modify is still called first; returning False just means "no active sub". 
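Taken together, the POST /credits/subscription tests imply a dispatch order for the handler that is not itself shown in this diff. A rough reconstruction from the assertions, leaving out the payment feature flag, URL validation, and the admin-granted paid-to-paid path; treat the names and ordering as an approximation rather than the handler's actual code:

```python
import asyncio


async def choose_subscription_action(
    current_tier: str,
    target_tier: str,
    has_price_for_target: bool,
    modify_existing_subscription,  # async callable; True when an active sub was updated
) -> str:
    """Approximate branch order implied by the tests in this module."""
    if target_tier == current_tier:
        # "Stay on my current tier": release any pending change, never start Checkout.
        return "release_pending_schedule"
    if target_tier == "BASIC":
        # Downgrade: schedule Stripe cancellation at period end
        # (or flip the DB tier directly when no subscription exists).
        return "cancel_at_period_end"
    if not has_price_for_target:
        # No LaunchDarkly-configured price: reject with 422 before touching Stripe.
        return "reject_422"
    if await modify_existing_subscription():
        # An active paid subscription was moved to the new price in place.
        return "modified_existing"
    # Nothing to modify: start a Stripe Checkout Session for the new tier.
    return "checkout"


async def _demo() -> None:
    async def no_active_sub() -> bool:
        return False

    assert await choose_subscription_action("PRO", "PRO", True, no_active_sub) == (
        "release_pending_schedule"
    )
    assert await choose_subscription_action("BASIC", "PRO", True, no_active_sub) == "checkout"


asyncio.run(_demo())
```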
+ modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO) + + +def test_update_subscription_tier_target_without_ld_price_returns_422( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Paid target with no LD-configured Stripe price must fail fast with 422. + + Matches the UI hiding: if `stripe-price-id-pro` resolves to None we can't + start a Checkout Session anyway, and we don't want to surface an opaque + Stripe error mid-flow. The handler rejects the request before touching + Stripe at all. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BASIC + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return None # Neither BASIC nor PRO have an LD price. + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 422 + assert "not available" in response.json()["detail"].lower() + checkout_mock.assert_not_awaited() + modify_mock.assert_not_awaited() + + def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, @@ -733,6 +1116,16 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502( mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO + async def price_id_with_business(tier: SubscriptionTier) -> str | None: + return { + **_DEFAULT_TIER_PRICES, + SubscriptionTier.BUSINESS: "price_business", + }.get(tier) + + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=price_id_with_business, + ) mocker.patch( "backend.api.features.v1.get_user_by_id", new_callable=AsyncMock, @@ -761,11 +1154,11 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502( assert response.status_code == 502 -def test_update_subscription_tier_free_no_stripe_subscription( +def test_update_subscription_tier_basic_no_stripe_subscription( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """Downgrading to FREE when no Stripe subscription exists updates DB tier directly. + """Downgrading to BASIC when no Stripe subscription exists updates DB tier directly. Admin-granted paid tiers have no associated Stripe subscription. 
When such a user requests a self-service downgrade, cancel_stripe_subscription returns False @@ -796,10 +1189,214 @@ def test_update_subscription_tier_free_no_stripe_subscription( new_callable=AsyncMock, ) - response = client.post("/credits/subscription", json={"tier": "FREE"}) + response = client.post("/credits/subscription", json={"tier": "BASIC"}) assert response.status_code == 200 assert response.json()["url"] == "" cancel_mock.assert_awaited_once_with(TEST_USER_ID) # DB tier must be updated immediately — no webhook will fire for a missing sub - set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE) + set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BASIC) + + +def test_get_subscription_status_includes_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription exposes pending_tier and pending_tier_effective_at.""" + import datetime as dt + + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + effective_at = dt.datetime(2030, 1, 1, tzinfo=dt.timezone.utc) + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return None + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=(SubscriptionTier.PRO, effective_at), + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] == "PRO" + assert data["pending_tier_effective_at"] is not None + + +def test_get_subscription_status_no_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """When no pending change exists the response omits pending_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] is None + assert data["pending_tier_effective_at"] is None + + +def test_update_subscription_tier_downgrade_paid_to_paid_schedules( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """A BUSINESS→PRO downgrade request dispatches to modify_stripe_subscription_for_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + async def price_id_with_business(tier: SubscriptionTier) -> str | None: + return { + **_DEFAULT_TIER_PRICES, + SubscriptionTier.BUSINESS: "price_business", + }.get(tier) + + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=price_id_with_business, + ) + mocker.patch( + 
"backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO) + checkout_mock.assert_not_awaited() + + +def test_stripe_webhook_dispatches_subscription_schedule_released( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.released routes to sync_subscription_schedule_from_stripe.""" + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.released", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_awaited_once_with(schedule_obj) + + +def test_stripe_webhook_ignores_subscription_schedule_updated( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.updated must NOT dispatch: our own + SubscriptionSchedule.create/.modify calls fire this event and would + otherwise loop redundant traffic through the sync handler. State + transitions we care about surface via .released/.completed, and phase + advance to a new price is already covered by customer.subscription.updated. 
+ """ + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.updated", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_not_awaited() diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index ab0b69071d..e47b05fa3d 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -26,10 +26,11 @@ from fastapi import ( ) from fastapi.concurrency import run_in_threadpool from prisma.enums import SubscriptionTier -from pydantic import BaseModel +from pydantic import BaseModel, Field from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND from typing_extensions import Optional, TypedDict +from backend.api.features.workspace.routes import create_file_download_response from backend.api.model import ( CreateAPIKeyRequest, CreateAPIKeyResponse, @@ -43,26 +44,31 @@ from backend.api.model import ( UploadFileResponse, ) from backend.blocks import get_block, get_blocks +from backend.copilot.rate_limit import get_tier_multipliers from backend.data import execution as execution_db from backend.data import graph as graph_db from backend.data.auth import api_key as api_key_db from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.credit import ( AutoTopUpConfig, + PendingChangeUnknown, RefundRequest, TransactionHistory, UserCredit, cancel_stripe_subscription, create_subscription_checkout, get_auto_top_up, + get_pending_subscription_change, get_proration_credit_cents, get_subscription_price_id, get_user_credit_model, handle_subscription_payment_failure, modify_stripe_subscription_for_tier, + release_pending_subscription_schedule, set_auto_top_up, set_subscription_tier, sync_subscription_from_stripe, + sync_subscription_schedule_from_stripe, ) from backend.data.graph import GraphSettings from backend.data.model import CredentialsMetaInput, UserOnboarding @@ -92,6 +98,7 @@ from backend.data.user import ( update_user_notification_preference, update_user_timezone, ) +from backend.data.workspace import get_workspace_file_by_id from backend.executor import scheduler from backend.executor import utils as execution_utils from backend.integrations.webhooks.graph_lifecycle_hooks import ( @@ -693,20 +700,35 @@ async def get_user_auto_top_up( class SubscriptionTierRequest(BaseModel): - tier: Literal["FREE", "PRO", "BUSINESS"] + tier: Literal["BASIC", "PRO", "MAX", "BUSINESS"] success_url: str = "" cancel_url: str = "" -class SubscriptionCheckoutResponse(BaseModel): - url: str - - class SubscriptionStatusResponse(BaseModel): - tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"] + tier: Literal["BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"] monthly_cost: int # amount in cents (Stripe convention) tier_costs: dict[str, int] # tier name -> amount in cents + tier_multipliers: dict[str, float] = Field( + default_factory=dict, + description=( + "Tier → rate-limit multiplier. 
Covers the same tiers listed in" + " ``tier_costs`` so the frontend can render rate-limit badges" + " relative to the lowest visible tier without knowing backend" + " defaults." + ), + ) proration_credit_cents: int # unused portion of current sub to convert on upgrade + pending_tier: Optional[Literal["BASIC", "PRO", "MAX", "BUSINESS"]] = None + pending_tier_effective_at: Optional[datetime] = None + url: str = Field( + default="", + description=( + "Populated only when POST /credits/subscription starts a Stripe Checkout" + " Session (BASIC → paid upgrade). Empty string in all other branches —" + " the client redirects to this URL when non-empty." + ), + ) def _validate_checkout_redirect_url(url: str) -> bool: @@ -782,39 +804,80 @@ async def get_subscription_status( user_id: Annotated[str, Security(get_user_id)], ) -> SubscriptionStatusResponse: user = await get_user_by_id(user_id) - tier = user.subscription_tier or SubscriptionTier.FREE + tier = user.subscription_tier or SubscriptionTier.BASIC - paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS] + priceable_tiers = [ + SubscriptionTier.BASIC, + SubscriptionTier.PRO, + SubscriptionTier.MAX, + SubscriptionTier.BUSINESS, + ] price_ids = await asyncio.gather( - *[get_subscription_price_id(t) for t in paid_tiers] + *[get_subscription_price_id(t) for t in priceable_tiers] ) - tier_costs: dict[str, int] = { - SubscriptionTier.FREE.value: 0, - SubscriptionTier.ENTERPRISE.value: 0, - } - async def _cost(pid: str | None) -> int: return (await _get_stripe_price_amount(pid) or 0) if pid else 0 costs = await asyncio.gather(*[_cost(pid) for pid in price_ids]) - for t, cost in zip(paid_tiers, costs): - tier_costs[t.value] = cost + + tier_costs: dict[str, int] = {} + for t, pid, cost in zip(priceable_tiers, price_ids, costs): + if pid: + tier_costs[t.value] = cost + + # Expose the effective rate-limit multipliers alongside prices so the + # frontend can render "Nx rate limits" relative to the lowest visible + # tier without hard-coding backend defaults. Only emit entries for tiers + # that land in ``tier_costs`` — rows hidden at the price layer must stay + # hidden in the multiplier layer too. + multipliers = await get_tier_multipliers() + tier_multipliers: dict[str, float] = { + t.value: multipliers.get(t, 1.0) + for t in priceable_tiers + if t.value in tier_costs + } current_monthly_cost = tier_costs.get(tier.value, 0) proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost) - return SubscriptionStatusResponse( + try: + pending = await get_pending_subscription_change(user_id) + except (stripe.StripeError, PendingChangeUnknown): + # Swallow Stripe-side failures (rate limits, transient network) AND + # PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both + # propagate past the cache so the next request retries fresh instead + # of serving a stale None for the TTL window. Let real bugs (KeyError, + # AttributeError, etc.) propagate so they surface in Sentry. 
+ logger.exception( + "get_subscription_status: failed to resolve pending change for user %s", + user_id, + ) + pending = None + + response = SubscriptionStatusResponse( tier=tier.value, monthly_cost=current_monthly_cost, tier_costs=tier_costs, + tier_multipliers=tier_multipliers, proration_credit_cents=proration_credit, ) + if pending is not None: + pending_tier_enum, pending_effective_at = pending + if pending_tier_enum in ( + SubscriptionTier.BASIC, + SubscriptionTier.PRO, + SubscriptionTier.MAX, + SubscriptionTier.BUSINESS, + ): + response.pending_tier = pending_tier_enum.value + response.pending_tier_effective_at = pending_effective_at + return response @v1_router.post( path="/credits/subscription", - summary="Start a Stripe Checkout session to upgrade subscription tier", + summary="Update subscription tier or start a Stripe Checkout session", operation_id="updateSubscriptionTier", tags=["credits"], dependencies=[Security(requires_user)], @@ -822,38 +885,63 @@ async def get_subscription_status( async def update_subscription_tier( request: SubscriptionTierRequest, user_id: Annotated[str, Security(get_user_id)], -) -> SubscriptionCheckoutResponse: - # Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type. +) -> SubscriptionStatusResponse: + # Pydantic validates tier is one of BASIC/PRO/MAX/BUSINESS via Literal type. tier = SubscriptionTier(request.tier) # ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users. user = await get_user_by_id(user_id) - if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE: + if ( + user.subscription_tier or SubscriptionTier.BASIC + ) == SubscriptionTier.ENTERPRISE: raise HTTPException( status_code=403, detail="ENTERPRISE subscription changes must be managed by an administrator", ) + # Same-tier request = "stay on my current tier" = cancel any pending + # scheduled change (paid→paid downgrade or paid→BASIC cancel). This is the + # collapsed behaviour that replaces the old /credits/subscription/cancel-pending + # route. Safe when no pending change exists: release_pending_subscription_schedule + # returns False and we simply return the current status. + if (user.subscription_tier or SubscriptionTier.BASIC) == tier: + try: + await release_pending_subscription_schedule(user_id) + except stripe.StripeError as e: + logger.exception( + "Stripe error releasing pending subscription change for user %s: %s", + user_id, + e, + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to cancel the pending subscription change right now. " + "Please try again or contact support." + ), + ) + return await get_subscription_status(user_id) + payment_enabled = await is_feature_enabled( Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False ) - # Downgrade to FREE: schedule Stripe cancellation at period end so the user - # keeps their tier for the time they already paid for. The DB tier is NOT - # updated here when a subscription exists — the customer.subscription.deleted - # webhook fires at period end and downgrades to FREE then. - # Exception: if the user has no active Stripe subscription (e.g. admin-granted - # tier), cancel_stripe_subscription returns False and we update the DB tier - # immediately since no webhook will ever fire. - # When payment is disabled entirely, update the DB tier directly. 
- if tier == SubscriptionTier.FREE: + current_tier = user.subscription_tier or SubscriptionTier.BASIC + target_price_id, current_tier_price_id = await asyncio.gather( + get_subscription_price_id(tier), + get_subscription_price_id(current_tier), + ) + + # Legacy cancel: target BASIC + stripe-price-id-basic unset. Schedule Stripe + # cancellation at period end; cancel_at_period_end=True lets the webhook flip + # the DB tier. No active sub (admin-granted) or payment disabled → DB flip. + # Once stripe-price-id-basic is configured, BASIC becomes a real sub and falls + # through to the modify/checkout flow below. + if tier == SubscriptionTier.BASIC and target_price_id is None: if payment_enabled: try: had_subscription = await cancel_stripe_subscription(user_id) except stripe.StripeError as e: - # Log full Stripe error server-side but return a generic message - # to the client — raw Stripe errors can leak customer/sub IDs and - # infrastructure config details. logger.exception( "Stripe error cancelling subscription for user %s: %s", user_id, @@ -867,48 +955,37 @@ async def update_subscription_tier( ), ) if not had_subscription: - # No active Stripe subscription found — the user was on an - # admin-granted tier. Update DB immediately since the - # subscription.deleted webhook will never fire. await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) - # Paid tier changes require payment to be enabled — block self-service upgrades - # when the flag is off. Admins use the /api/admin/ routes to set tiers directly. if not payment_enabled: raise HTTPException( status_code=422, - detail=f"Subscription not available for tier {tier}", + detail=f"Subscription not available for tier {tier.value}", ) - # No-op short-circuit: if the user is already on the requested paid tier, - # do NOT create a new Checkout Session. Without this guard, a duplicate - # request (double-click, retried POST, stale page) creates a second - # subscription for the same price; the user would be charged for both - # until `_cleanup_stale_subscriptions` runs from the resulting webhook — - # which only fires after the second charge has cleared. - if (user.subscription_tier or SubscriptionTier.FREE) == tier: - return SubscriptionCheckoutResponse(url="") + # Target has no LD price — not provisionable (matches the GET hiding). + if target_price_id is None: + raise HTTPException( + status_code=422, + detail=f"Subscription not available for tier {tier.value}", + ) - # Paid→paid tier change: if the user already has a Stripe subscription, - # modify it in-place with proration instead of creating a new Checkout - # Session. This preserves remaining paid time and avoids double-charging. - # The customer.subscription.updated webhook fires and updates the DB tier. - current_tier = user.subscription_tier or SubscriptionTier.FREE - if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS): + # User has an active Stripe subscription (current tier has an LD price): + # modify it in-place. modify_stripe_subscription_for_tier returns False when no + # active sub exists — that's only a "DB-only flip is OK" signal for admin-granted + # paid tiers (PRO/BUSINESS with no Stripe record). Priced-BASIC users without a + # sub must still go through Checkout so they set up payment. 
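For orientation, the branch order that this comment and the surrounding handler code establish can be summarised as follows. This is a hedged sketch, not the handler itself; every name in it is introduced elsewhere in this diff:

```python
from fastapi import HTTPException
from prisma.enums import SubscriptionTier


async def dispatch_sketch(current, tier, target_price_id, current_price_id, payment_enabled):
    # Condensed illustration of update_subscription_tier's decision order.
    if current == SubscriptionTier.ENTERPRISE:
        raise HTTPException(status_code=403)          # admin-managed only
    if current == tier:
        return "release any pending scheduled change, return current status"
    if tier == SubscriptionTier.BASIC and target_price_id is None:
        return "legacy cancel: period-end Stripe cancel or direct DB flip"
    if not payment_enabled or target_price_id is None:
        raise HTTPException(status_code=422)          # tier not provisionable
    if current_price_id is not None:
        return "modify the existing Stripe subscription in place"
    return "create a Checkout Session; status.url carries the redirect"
```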
+ if current_tier_price_id is not None: try: modified = await modify_stripe_subscription_for_tier(user_id, tier) if modified: - return SubscriptionCheckoutResponse(url="") - # modify_stripe_subscription_for_tier returns False when no active - # Stripe subscription exists — i.e. the user has an admin-granted - # paid tier with no Stripe record. In that case, update the DB - # tier directly (same as the FREE-downgrade path for admin-granted - # users) rather than sending them through a new Checkout Session. - await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) + if current_tier != SubscriptionTier.BASIC: + await set_subscription_tier(user_id, tier) + return await get_subscription_status(user_id) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) except stripe.StripeError as e: @@ -923,7 +1000,7 @@ async def update_subscription_tier( ), ) - # Paid upgrade from FREE → create Stripe Checkout Session. + # No active Stripe subscription → create Stripe Checkout Session. if not request.success_url or not request.cancel_url: raise HTTPException( status_code=422, @@ -978,7 +1055,9 @@ async def update_subscription_tier( ), ) - return SubscriptionCheckoutResponse(url=url) + status = await get_subscription_status(user_id) + status.url = url + return status @v1_router.post( @@ -1043,6 +1122,18 @@ async def stripe_webhook(request: Request): ): await sync_subscription_from_stripe(data_object) + # `subscription_schedule.updated` is deliberately omitted: our own + # `SubscriptionSchedule.create` + `.modify` calls in + # `_schedule_downgrade_at_period_end` would fire that event right back at us + # and loop redundant traffic through this handler. We only care about state + # transitions (released / completed); phase advance to the new price is + # already covered by `customer.subscription.updated`. + if event_type in ( + "subscription_schedule.released", + "subscription_schedule.completed", + ): + await sync_subscription_schedule_from_stripe(data_object) + if event_type == "invoice.payment_failed": await handle_subscription_payment_failure(data_object) @@ -1640,6 +1731,10 @@ async def enable_execution_sharing( # Generate a unique share token share_token = str(uuid.uuid4()) + # Remove stale allowlist records before updating the token — prevents a + # window where old records + new token could coexist. 
+ await execution_db.delete_shared_execution_files(execution_id=graph_exec_id) + # Update the execution with share info await execution_db.update_graph_execution_share_status( execution_id=graph_exec_id, @@ -1649,6 +1744,14 @@ async def enable_execution_sharing( shared_at=datetime.now(timezone.utc), ) + # Create allowlist of workspace files referenced in outputs + await execution_db.create_shared_execution_files( + execution_id=graph_exec_id, + share_token=share_token, + user_id=user_id, + outputs=execution.outputs, + ) + # Return the share URL frontend_url = settings.config.frontend_base_url or "http://localhost:3000" share_url = f"{frontend_url}/share/{share_token}" @@ -1674,6 +1777,9 @@ async def disable_execution_sharing( if not execution: raise HTTPException(status_code=404, detail="Execution not found") + # Remove shared file allowlist records + await execution_db.delete_shared_execution_files(execution_id=graph_exec_id) + # Remove share info await execution_db.update_graph_execution_share_status( execution_id=graph_exec_id, @@ -1699,6 +1805,43 @@ async def get_shared_execution( return execution +@v1_router.get( + "/public/shared/{share_token}/files/{file_id}/download", + summary="Download a file from a shared execution", + operation_id="download_shared_file", + tags=["graphs"], +) +async def download_shared_file( + share_token: Annotated[ + str, + Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"), + ], + file_id: Annotated[ + str, + Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"), + ], +) -> Response: + """Download a workspace file from a shared execution (no auth required). + + Validates that the file was explicitly exposed when sharing was enabled. + Returns a uniform 404 for all failure modes to prevent enumeration attacks. 
+ """ + # Single-query validation against the allowlist + execution_id = await execution_db.get_shared_execution_file( + share_token=share_token, file_id=file_id + ) + if not execution_id: + raise HTTPException(status_code=404, detail="Not found") + + # Look up the actual file (no workspace scoping needed — the allowlist + # already validated that this file belongs to the shared execution) + file = await get_workspace_file_by_id(file_id) + if not file: + raise HTTPException(status_code=404, detail="Not found") + + return await create_file_download_response(file, inline=True) + + ######################################################## ##################### Schedules ######################## ######################################################## diff --git a/autogpt_platform/backend/backend/api/features/v1_share_test.py b/autogpt_platform/backend/backend/api/features/v1_share_test.py new file mode 100644 index 0000000000..de5d14ad80 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/v1_share_test.py @@ -0,0 +1,157 @@ +"""Tests for the public shared file download endpoint.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from starlette.responses import Response + +from backend.api.features.v1 import v1_router +from backend.data.workspace import WorkspaceFile + +app = FastAPI() +app.include_router(v1_router, prefix="/api") + +VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000" +VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + + +def _make_workspace_file(**overrides) -> WorkspaceFile: + defaults = { + "id": VALID_FILE_ID, + "workspace_id": "ws-001", + "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "name": "image.png", + "path": "/image.png", + "storage_path": "local://uploads/image.png", + "mime_type": "image/png", + "size_bytes": 4, + "checksum": None, + "is_deleted": False, + "deleted_at": None, + "metadata": {}, + } + defaults.update(overrides) + return WorkspaceFile(**defaults) + + +def _mock_download_response(**kwargs): + """Return an AsyncMock that resolves to a Response with inline disposition.""" + + async def _handler(file, *, inline=False): + return Response( + content=b"\x89PNG", + media_type="image/png", + headers={ + "Content-Disposition": ( + 'inline; filename="image.png"' + if inline + else 'attachment; filename="image.png"' + ), + "Content-Length": "4", + }, + ) + + return _handler + + +class TestDownloadSharedFile: + """Tests for GET /api/public/shared/{token}/files/{id}/download.""" + + @pytest.fixture(autouse=True) + def _client(self): + self.client = TestClient(app, raise_server_exceptions=False) + + def test_valid_token_and_file_returns_inline_content(self): + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=_make_workspace_file(), + ), + patch( + "backend.api.features.v1.create_file_download_response", + side_effect=_mock_download_response(), + ), + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + assert response.status_code == 200 + assert response.content == b"\x89PNG" + assert "inline" in response.headers["Content-Disposition"] + + def test_invalid_token_format_returns_422(self): + response = 
self.client.get( + f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 422 + + def test_token_not_in_allowlist_returns_404(self): + with patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value=None, + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 404 + + def test_file_missing_from_workspace_returns_404(self): + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=None, + ), + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 404 + + def test_uniform_404_prevents_enumeration(self): + """Both failure modes produce identical 404 — no information leak.""" + with patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value=None, + ): + resp_no_allow = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=None, + ), + ): + resp_no_file = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + assert resp_no_allow.status_code == 404 + assert resp_no_file.status_code == 404 + assert resp_no_allow.json() == resp_no_file.json() diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes.py b/autogpt_platform/backend/backend/api/features/workspace/routes.py index b96e953491..247749f895 100644 --- a/autogpt_platform/backend/backend/api/features/workspace/routes.py +++ b/autogpt_platform/backend/backend/api/features/workspace/routes.py @@ -31,7 +31,9 @@ from backend.util.workspace import WorkspaceManager from backend.util.workspace_storage import get_workspace_storage -def _sanitize_filename_for_header(filename: str) -> str: +def _sanitize_filename_for_header( + filename: str, disposition: str = "attachment" +) -> str: """ Sanitize filename for Content-Disposition header to prevent header injection. 
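The sanitizer's behaviour with the new `disposition` parameter is exercised by the tests further below; a minimal usage sketch, assuming the module import path shown in this diff:

```python
from backend.api.features.workspace.routes import _sanitize_filename_for_header

# ASCII names stay in the quoted filename= form; disposition defaults to "attachment".
assert _sanitize_filename_for_header("report.pdf") == 'attachment; filename="report.pdf"'
assert _sanitize_filename_for_header("image.png", disposition="inline") == 'inline; filename="image.png"'
# Non-ASCII names fall back to RFC 5987 filename*= encoding, keeping the chosen disposition.
assert _sanitize_filename_for_header("日本語.pdf").startswith("attachment; filename*=UTF-8''")
```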
@@ -46,11 +48,11 @@ def _sanitize_filename_for_header(filename: str) -> str: # Check if filename has non-ASCII characters try: sanitized.encode("ascii") - return f'attachment; filename="{sanitized}"' + return f'{disposition}; filename="{sanitized}"' except UnicodeEncodeError: # Use RFC5987 encoding for UTF-8 filenames encoded = quote(sanitized, safe="") - return f"attachment; filename*=UTF-8''{encoded}" + return f"{disposition}; filename*=UTF-8''{encoded}" logger = logging.getLogger(__name__) @@ -60,19 +62,26 @@ router = fastapi.APIRouter( ) -def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response: +def _create_streaming_response( + content: bytes, file: WorkspaceFile, *, inline: bool = False +) -> Response: """Create a streaming response for file content.""" + disposition = _sanitize_filename_for_header( + file.name, disposition="inline" if inline else "attachment" + ) return Response( content=content, media_type=file.mime_type, headers={ - "Content-Disposition": _sanitize_filename_for_header(file.name), + "Content-Disposition": disposition, "Content-Length": str(len(content)), }, ) -async def _create_file_download_response(file: WorkspaceFile) -> Response: +async def create_file_download_response( + file: WorkspaceFile, *, inline: bool = False +) -> Response: """ Create a download response for a workspace file. @@ -84,7 +93,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # For local storage, stream the file directly if file.storage_path.startswith("local://"): content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) # For GCS, try to redirect to signed URL, fall back to streaming try: @@ -92,7 +101,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # If we got back an API path (fallback), stream directly instead if url.startswith("/api/"): content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) return fastapi.responses.RedirectResponse(url=url, status_code=302) except Exception as e: # Log the signed URL failure with context @@ -104,7 +113,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # Fall back to streaming directly from GCS try: content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) except Exception as fallback_error: logger.error( f"Fallback streaming also failed for file {file.id} " @@ -171,7 +180,7 @@ async def download_file( if file is None: raise fastapi.HTTPException(status_code=404, detail="File not found") - return await _create_file_download_response(file) + return await create_file_download_response(file) @router.delete( diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes_test.py b/autogpt_platform/backend/backend/api/features/workspace/routes_test.py index 37cfcf90da..5c00e9a9f2 100644 --- a/autogpt_platform/backend/backend/api/features/workspace/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/workspace/routes_test.py @@ -630,3 +630,221 @@ def test_get_storage_usage_returns_tier_based_limit(mocker): assert data["limit_bytes"] == 1024 * 1024 * 1024 assert data["used_bytes"] == 100 * 1024 * 1024 assert data["file_count"] == 5 + + +# -- _sanitize_filename_for_header tests -- + + +class 
TestSanitizeFilenameForHeader: + def test_simple_ascii_attachment(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + assert _sanitize_filename_for_header("report.pdf") == ( + 'attachment; filename="report.pdf"' + ) + + def test_inline_disposition(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + assert _sanitize_filename_for_header("image.png", disposition="inline") == ( + 'inline; filename="image.png"' + ) + + def test_strips_cr_lf_null(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("a\rb\nc\x00d.txt") + assert "\r" not in result + assert "\n" not in result + assert "\x00" not in result + assert 'filename="abcd.txt"' in result + + def test_escapes_quotes(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header('file"name.txt') + assert 'filename="file\\"name.txt"' in result + + def test_header_injection_blocked(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true") + # CR/LF stripped — the remaining text is safely inside the quoted value + assert "\r" not in result + assert "\n" not in result + assert result == 'attachment; filename="evil.txtX-Injected: true"' + + def test_unicode_uses_rfc5987(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("日本語.pdf") + assert "filename*=UTF-8''" in result + assert "attachment" in result + + def test_unicode_inline(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("图片.png", disposition="inline") + assert result.startswith("inline; filename*=UTF-8''") + + def test_empty_filename(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("") + assert result == 'attachment; filename=""' + + +# -- _create_streaming_response tests -- + + +class TestCreateStreamingResponse: + def test_attachment_disposition_by_default(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name="data.bin", mime_type="application/octet-stream") + response = _create_streaming_response(b"binary-data", file) + assert ( + response.headers["Content-Disposition"] == 'attachment; filename="data.bin"' + ) + assert response.headers["Content-Type"] == "application/octet-stream" + assert response.headers["Content-Length"] == "11" + assert response.body == b"binary-data" + + def test_inline_disposition(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name="photo.png", mime_type="image/png") + response = _create_streaming_response(b"\x89PNG", file, inline=True) + assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"' + assert response.headers["Content-Type"] == "image/png" + + def test_inline_sanitizes_filename(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name='evil"\r\n.txt', mime_type="text/plain") + response = _create_streaming_response(b"data", file, inline=True) + assert "\r" not in response.headers["Content-Disposition"] + assert "\n" not in response.headers["Content-Disposition"] + assert "inline" in 
response.headers["Content-Disposition"] + + def test_content_length_matches_body(self): + from backend.api.features.workspace.routes import _create_streaming_response + + content = b"x" * 1000 + file = _make_file(name="big.bin", mime_type="application/octet-stream") + response = _create_streaming_response(content, file) + assert response.headers["Content-Length"] == "1000" + + +# -- create_file_download_response tests -- + + +class TestCreateFileDownloadResponse: + @pytest.mark.asyncio + async def test_local_storage_returns_streaming_response(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.retrieve.return_value = b"file contents" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file( + storage_path="local://uploads/test.txt", + mime_type="text/plain", + ) + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"file contents" + assert "attachment" in response.headers["Content-Disposition"] + + @pytest.mark.asyncio + async def test_local_storage_inline(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.retrieve.return_value = b"\x89PNG" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file( + storage_path="local://uploads/photo.png", + mime_type="image/png", + name="photo.png", + ) + response = await create_file_download_response(file, inline=True) + assert "inline" in response.headers["Content-Disposition"] + + @pytest.mark.asyncio + async def test_gcs_redirect(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.return_value = ( + "https://storage.googleapis.com/signed-url" + ) + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.pdf") + response = await create_file_download_response(file) + assert response.status_code == 302 + assert ( + response.headers["location"] == "https://storage.googleapis.com/signed-url" + ) + + @pytest.mark.asyncio + async def test_gcs_api_fallback_streams_directly(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.return_value = "/api/fallback" + mock_storage.retrieve.return_value = b"fallback content" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"fallback content" + + @pytest.mark.asyncio + async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.side_effect = RuntimeError("GCS error") + mock_storage.retrieve.return_value = b"streamed" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + response = await 
create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"streamed" + + @pytest.mark.asyncio + async def test_gcs_total_failure_raises(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.side_effect = RuntimeError("GCS error") + mock_storage.retrieve.side_effect = RuntimeError("Also failed") + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + with pytest.raises(RuntimeError, match="Also failed"): + await create_file_download_response(file) diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index 2b2dba397e..abe261b725 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -17,6 +17,7 @@ from fastapi.routing import APIRoute from prisma.errors import PrismaError import backend.api.features.admin.credit_admin_routes +import backend.api.features.admin.diagnostics_admin_routes import backend.api.features.admin.execution_analytics_routes import backend.api.features.admin.platform_cost_routes import backend.api.features.admin.rate_limit_admin_routes @@ -31,6 +32,7 @@ import backend.api.features.library.routes import backend.api.features.mcp.routes as mcp_routes import backend.api.features.oauth import backend.api.features.otto.routes +import backend.api.features.platform_linking.routes import backend.api.features.postmark.postmark import backend.api.features.store.model import backend.api.features.store.routes @@ -320,6 +322,11 @@ app.include_router( tags=["v2", "admin"], prefix="/api/credits", ) +app.include_router( + backend.api.features.admin.diagnostics_admin_routes.router, + tags=["v2", "admin"], + prefix="/api", +) app.include_router( backend.api.features.admin.execution_analytics_routes.router, tags=["v2", "admin"], @@ -372,6 +379,11 @@ app.include_router( tags=["oauth"], prefix="/api/oauth", ) +app.include_router( + backend.api.features.platform_linking.routes.router, + tags=["platform-linking"], + prefix="/api/platform-linking", +) app.mount("/external-api", external_api) diff --git a/autogpt_platform/backend/backend/app.py b/autogpt_platform/backend/backend/app.py index 236f098761..534f385009 100644 --- a/autogpt_platform/backend/backend/app.py +++ b/autogpt_platform/backend/backend/app.py @@ -42,11 +42,13 @@ def main(**kwargs): from backend.data.db_manager import DatabaseManager from backend.executor import ExecutionManager, Scheduler from backend.notifications import NotificationManager + from backend.platform_linking.manager import PlatformLinkingManager run_processes( DatabaseManager().set_log_level("warning"), Scheduler(), NotificationManager(), + PlatformLinkingManager(), WebsocketServer(), AgentServer(), ExecutionManager(), diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 2a26421c91..e2238bcae5 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -96,27 +96,64 @@ class BlockCategory(Enum): class BlockCostType(str, Enum): - RUN = "run" # cost X credits per run - BYTE = "byte" # cost X credits per byte - SECOND = "second" # cost X credits per second + # RUN : cost_amount credits per run. + # BYTE : cost_amount credits per byte of input data. 
+ # SECOND : cost_amount credits per cost_divisor walltime seconds. + # ITEMS : cost_amount credits per cost_divisor items (from stats). + # COST_USD : cost_amount credits per USD of stats.provider_cost. + # TOKENS : per-(model, provider) rate table lookup; see TOKEN_COST. + RUN = "run" + BYTE = "byte" + SECOND = "second" + ITEMS = "items" + COST_USD = "cost_usd" + TOKENS = "tokens" + + @property + def is_dynamic(self) -> bool: + """Real charge is computed post-flight from stats. + + Dynamic types (SECOND/ITEMS/COST_USD/TOKENS) return 0 pre-flight and + settle against stats via charge_reconciled_usage once the block runs. + """ + return self in _DYNAMIC_COST_TYPES + + +_DYNAMIC_COST_TYPES: frozenset[BlockCostType] = frozenset( + { + BlockCostType.SECOND, + BlockCostType.ITEMS, + BlockCostType.COST_USD, + BlockCostType.TOKENS, + } +) class BlockCost(BaseModel): cost_amount: int cost_filter: BlockInput cost_type: BlockCostType + # cost_divisor: interpret cost_amount as "credits per cost_divisor units". + # Only meaningful for SECOND / ITEMS. TOKENS routes through TOKEN_COST + # rate tables (per-model input/output/cache pricing) and ignores + # cost_divisor entirely. Defaults to 1 so existing RUN/BYTE entries stay + # point-wise. Example: cost_amount=1, cost_divisor=10 under SECOND means + # "1 credit per 10 seconds of walltime". + cost_divisor: int = 1 def __init__( self, cost_amount: int, cost_type: BlockCostType = BlockCostType.RUN, cost_filter: Optional[BlockInput] = None, + cost_divisor: int = 1, **data: Any, ) -> None: super().__init__( cost_amount=cost_amount, cost_filter=cost_filter or {}, cost_type=cost_type, + cost_divisor=max(1, cost_divisor), **data, ) @@ -168,9 +205,31 @@ class BlockSchema(BaseModel): return cls.cached_jsonschema @classmethod - def validate_data(cls, data: BlockInput) -> str | None: + def validate_data( + cls, + data: BlockInput, + exclude_fields: set[str] | None = None, + ) -> str | None: + schema = cls.jsonschema() + if exclude_fields: + # Drop the excluded fields from both the properties and the + # ``required`` list so jsonschema doesn't flag them as missing. + # Used by the dry-run path to skip credentials validation while + # still validating the remaining block inputs. + schema = { + **schema, + "properties": { + k: v + for k, v in schema.get("properties", {}).items() + if k not in exclude_fields + }, + "required": [ + r for r in schema.get("required", []) if r not in exclude_fields + ], + } + data = {k: v for k, v in data.items() if k not in exclude_fields} return json.validate_with_jsonschema( - schema=cls.jsonschema(), + schema=schema, data={k: v for k, v in data.items() if v is not None}, ) @@ -311,6 +370,8 @@ class BlockSchema(BaseModel): "credentials_provider": [config.get("provider", "google")], "credentials_types": [config.get("type", "oauth2")], "credentials_scopes": config.get("scopes"), + "is_auto_credential": True, + "input_field_name": info["field_name"], } result[kwarg_name] = CredentialsFieldInfo.model_validate( auto_schema, by_alias=True @@ -421,19 +482,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig): class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): _optimized_description: ClassVar[str | None] = None - def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int: - """Return extra runtime cost to charge after this block run completes. - - Called by the executor after a block finishes with COMPLETED status. 
- The return value is the number of additional base-cost credits to - charge beyond the single credit already collected by charge_usage - at the start of execution. Defaults to 0 (no extra charges). - - Override in blocks (e.g. OrchestratorBlock) that make multiple LLM - calls within one run and should be billed per call. - """ - return 0 - def __init__( self, id: str = "", @@ -717,11 +765,16 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): # (e.g. AgentExecutorBlock) get proper input validation. is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False) if is_dry_run: + # Credential fields may be absent (LLM-built agents often skip + # wiring them) or nullified earlier in the pipeline. Validate + # the non-credential inputs against a schema with those fields + # excluded — stripping only the data while keeping them in the + # ``required`` list would falsely report ``'credentials' is a + # required property``. cred_field_names = set(self.input_schema.get_credentials_fields().keys()) - non_cred_data = { - k: v for k, v in input_data.items() if k not in cred_field_names - } - if error := self.input_schema.validate_data(non_cred_data): + if error := self.input_schema.validate_data( + input_data, exclude_fields=cred_field_names + ): raise BlockInputError( message=f"Unable to execute block with invalid input data: {error}", block_name=self.name, @@ -735,6 +788,61 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): block_id=self.id, ) + # Ensure auto-credential kwargs are present before we hand off to + # run(). A missing auto-credential means the upstream field (e.g. + # a Google Drive picker) didn't embed a _credentials_id, or the + # executor couldn't resolve it. Without this guard, run() would + # crash with a TypeError (missing required kwarg) or an opaque + # AttributeError deep inside the provider SDK. + # + # Only raise when the field is ALSO not populated in input_data. + # ``_acquire_auto_credentials`` intentionally skips setting the + # kwarg in two legitimate cases — ``_credentials_id`` is ``None`` + # (chained from upstream) or the field is missing from + # ``input_data`` at prep time (connected from upstream block). + # In both cases the upstream block is expected to populate the + # field value by execute time; raising here would break the + # documented ``AgentGoogleDriveFileInputBlock`` chaining pattern. + # Dry-run skips because the executor intentionally runs blocks + # without resolved creds for schema validation. + if not is_dry_run: + for ( + kwarg_name, + info, + ) in self.input_schema.get_auto_credentials_fields().items(): + kwargs.setdefault(kwarg_name, None) + if kwargs[kwarg_name] is not None: + continue + # Upstream-chained pattern: the field was populated by a + # prior node (e.g. AgentGoogleDriveFileInputBlock) whose + # output carries a resolved ``_credentials_id``. + # ``_acquire_auto_credentials`` deliberately doesn't set + # the kwarg in that case because the value isn't available + # at prep time; the executor fills it in before we reach + # ``_execute``. Trust it if the ``_credentials_id`` KEY + # is present — its value may be explicitly ``None`` in + # the chained case (see sentry thread + # PRRT_kwDOJKSTjM58sJfA). Checking truthiness here would + # falsely preempt run() for every valid chained graph + # that ships ``_credentials_id=None`` in the picker + # object. Mirror ``_acquire_auto_credentials``'s own + # skip rule, which treats ``cred_id is None`` as a + # chained-skip signal. 
+ field_name = info["field_name"] + field_value = input_data.get(field_name) + if isinstance(field_value, dict) and "_credentials_id" in field_value: + continue + raise BlockExecutionError( + message=( + f"Missing credentials for '{kwarg_name}'. " + "Select a file via the picker (which carries " + "its credentials), or connect credentials for " + "this block." + ), + block_name=self.name, + block_id=self.id, + ) + # Use the validated input data async for output_name, output_data in self.run( self.input_schema(**{k: v for k, v in input_data.items() if v is not None}), diff --git a/autogpt_platform/backend/backend/blocks/agent.py b/autogpt_platform/backend/backend/blocks/agent.py index a4e5acff07..67eba1aa9c 100644 --- a/autogpt_platform/backend/backend/blocks/agent.py +++ b/autogpt_platform/backend/backend/blocks/agent.py @@ -171,7 +171,10 @@ class AgentExecutorBlock(Block): ) self.merge_stats( NodeExecutionStats( - extra_cost=event.stats.cost if event.stats else 0, + # Sub-graph already debited each of its own nodes; we + # roll up its total so graph_stats.cost reflects the + # full sub-graph spend. + reconciled_cost_delta=(event.stats.cost if event.stats else 0), extra_steps=event.stats.node_exec_count if event.stats else 0, ) ) diff --git a/autogpt_platform/backend/backend/blocks/agent_mail/_config.py b/autogpt_platform/backend/backend/blocks/agent_mail/_config.py index 414b19536a..cf5c0d0ff1 100644 --- a/autogpt_platform/backend/backend/blocks/agent_mail/_config.py +++ b/autogpt_platform/backend/backend/blocks/agent_mail/_config.py @@ -4,11 +4,16 @@ Shared configuration for all AgentMail blocks. from agentmail import AsyncAgentMail -from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr +from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, SecretStr +# AgentMail is in beta with no published paid tier yet, but ~37 blocks +# without any BLOCK_COSTS entry means they currently execute wallet-free. +# 1 cr/call is a conservative interim floor so no AgentMail work leaks +# past billing. Revisit once AgentMail publishes usage-based pricing. agent_mail = ( ProviderBuilder("agent_mail") .with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key") + .with_base_cost(1, BlockCostType.RUN) .build() ) diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/_config.py b/autogpt_platform/backend/backend/blocks/ayrshare/_config.py new file mode 100644 index 0000000000..811ce6673c --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/ayrshare/_config.py @@ -0,0 +1,21 @@ +"""Shared provider config for Ayrshare social-media blocks. + +The "credential" exposed to blocks is the **per-user Ayrshare profile key**, +not the org-level ``AYRSHARE_API_KEY``. Profile keys are provisioned per +user by :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider` +and stored in the normal credentials list with ``is_managed=True``, so every +Ayrshare block fits the standard credential flow: + + credentials: CredentialsMetaInput = ayrshare.credentials_field(...) + +``run_block`` / ``resolve_block_credentials`` take care of the rest. + +``with_managed_api_key()`` registers ``api_key`` as a supported auth type +without the env-var-backed default credential that ``with_api_key()`` would +create — the org-level ``AYRSHARE_API_KEY`` is the admin key and must never +reach a block as a "profile key". 
+""" + +from backend.sdk import ProviderBuilder + +ayrshare = ProviderBuilder("ayrshare").with_managed_api_key().build() diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/_cost.py b/autogpt_platform/backend/backend/blocks/ayrshare/_cost.py new file mode 100644 index 0000000000..709d642b73 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/ayrshare/_cost.py @@ -0,0 +1,18 @@ +from backend.sdk import BlockCost, BlockCostType + +# Ayrshare is a subscription proxy ($149/mo Business). Per-post credit charges +# prevent a single heavy user from absorbing the fixed cost and align with the +# upload cost of each post variant. +# cost_filter matches on input_data.is_video BEFORE run() executes, so the flag +# has to be correct at input-eval time. Video-only platforms (YouTube, Snapchat) +# override the base default to True; platforms that accept both (TikTok, etc.) +# rely on the caller setting is_video explicitly for accurate billing. +# First match wins in block_usage_cost, so list the video tier first. +AYRSHARE_POST_COSTS = ( + BlockCost( + cost_amount=5, cost_type=BlockCostType.RUN, cost_filter={"is_video": True} + ), + BlockCost( + cost_amount=2, cost_type=BlockCostType.RUN, cost_filter={"is_video": False} + ), +) diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/_util.py b/autogpt_platform/backend/backend/blocks/ayrshare/_util.py index 231239310f..720925e19f 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/_util.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/_util.py @@ -4,22 +4,25 @@ from typing import Optional from pydantic import BaseModel, Field from backend.blocks._base import BlockSchemaInput -from backend.data.model import SchemaField, UserIntegrations +from backend.data.model import CredentialsMetaInput, SchemaField from backend.integrations.ayrshare import AyrshareClient -from backend.util.clients import get_database_manager_async_client from backend.util.exceptions import MissingConfigError - -async def get_profile_key(user_id: str): - user_integrations: UserIntegrations = ( - await get_database_manager_async_client().get_user_integrations(user_id) - ) - return user_integrations.managed_credentials.ayrshare_profile_key +from ._config import ayrshare class BaseAyrshareInput(BlockSchemaInput): """Base input model for Ayrshare social media posts with common fields.""" + credentials: CredentialsMetaInput = ayrshare.credentials_field( + description=( + "Ayrshare profile credential. AutoGPT provisions this managed " + "credential automatically — the user does not create it. After " + "it's in place, the user links each social account via the " + "Ayrshare SSO popup in the Builder." + ), + ) + post: str = SchemaField( description="The post text to be published", default="", advanced=False ) @@ -29,7 +32,9 @@ class BaseAyrshareInput(BlockSchemaInput): advanced=False, ) is_video: bool = SchemaField( - description="Whether the media is a video", default=False, advanced=True + description="Whether the media is a video. 
Set to True when uploading a video so billing applies the video tier.", + default=False, + advanced=True, ) schedule_date: Optional[datetime] = SchemaField( description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)", diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_bluesky.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_bluesky.py index df0d5ad269..a7254f7099 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_bluesky.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_bluesky.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToBlueskyBlock(Block): """Block for posting to Bluesky with Bluesky-specific options.""" @@ -57,16 +61,10 @@ class PostToBlueskyBlock(Block): self, input_data: "PostToBlueskyBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to Bluesky with Bluesky-specific options.""" - - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." @@ -106,7 +104,7 @@ class PostToBlueskyBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, bluesky_options=bluesky_options if bluesky_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_facebook.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_facebook.py index a9087915e6..2d4af969b1 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_facebook.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_facebook.py @@ -1,21 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import ( - BaseAyrshareInput, - CarouselItem, - create_ayrshare_client, - get_profile_key, -) +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToFacebookBlock(Block): """Block for posting to Facebook with Facebook-specific options.""" @@ -120,15 +119,10 @@ class PostToFacebookBlock(Block): self, input_data: "PostToFacebookBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to Facebook with Facebook-specific options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." 
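Taken together, `_config.py`, `_cost.py`, and `_util.py` give every Ayrshare block the same shape. A minimal sketch of a hypothetical new block following that pattern — `HypotheticalPostBlock` and its platform are illustrative only, and the actual Ayrshare client call is elided because its method signature is not shown in this diff:

```python
from backend.sdk import APIKeyCredentials, Block, BlockOutput, cost

from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client


@cost(*AYRSHARE_POST_COSTS)
class HypotheticalPostBlock(Block):
    class Input(BaseAyrshareInput):
        # Inherits credentials, post, media_urls, is_video, schedule_date, ...
        pass

    async def run(
        self,
        input_data: "HypotheticalPostBlock.Input",
        *,
        credentials: APIKeyCredentials,
        **kwargs,
    ) -> BlockOutput:
        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured."
            return
        # The managed per-user profile key arrives as an ordinary API-key credential;
        # no get_profile_key(user_id) lookup is needed anymore.
        profile_key = credentials.api_key.get_secret_value()
        # ... build platform options and call the Ayrshare client with
        # profile_key=profile_key, exactly as the concrete blocks in this PR do.
        yield "post", input_data.post
```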
@@ -204,7 +198,7 @@ class PostToFacebookBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, facebook_options=facebook_options if facebook_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_gmb.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_gmb.py index 1f223f1f80..1856cbef65 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_gmb.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_gmb.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToGMBBlock(Block): """Block for posting to Google My Business with GMB-specific options.""" @@ -110,14 +114,13 @@ class PostToGMBBlock(Block): ) async def run( - self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs + self, + input_data: "PostToGMBBlock.Input", + *, + credentials: APIKeyCredentials, + **kwargs ) -> BlockOutput: """Post to Google My Business with GMB-specific options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." 
@@ -202,7 +205,7 @@ class PostToGMBBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, gmb_options=gmb_options if gmb_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_instagram.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_instagram.py index 06d80db528..d468c1652a 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_instagram.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_instagram.py @@ -2,22 +2,21 @@ from typing import Any from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import ( - BaseAyrshareInput, - InstagramUserTag, - create_ayrshare_client, - get_profile_key, -) +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToInstagramBlock(Block): """Block for posting to Instagram with Instagram-specific options.""" @@ -112,15 +111,10 @@ class PostToInstagramBlock(Block): self, input_data: "PostToInstagramBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to Instagram with Instagram-specific options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." @@ -241,7 +235,7 @@ class PostToInstagramBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, instagram_options=instagram_options if instagram_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_linkedin.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_linkedin.py index 961587d201..01824cf994 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_linkedin.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_linkedin.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToLinkedInBlock(Block): """Block for posting to LinkedIn with LinkedIn-specific options.""" @@ -112,15 +116,10 @@ class PostToLinkedInBlock(Block): self, input_data: "PostToLinkedInBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to LinkedIn with LinkedIn-specific options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. 
Please set up the AYRSHARE_API_KEY." @@ -214,7 +213,7 @@ class PostToLinkedInBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, linkedin_options=linkedin_options if linkedin_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_pinterest.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_pinterest.py index 834cd4e301..df8a436cbe 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_pinterest.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_pinterest.py @@ -1,21 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import ( - BaseAyrshareInput, - PinterestCarouselOption, - create_ayrshare_client, - get_profile_key, -) +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToPinterestBlock(Block): """Block for posting to Pinterest with Pinterest-specific options.""" @@ -92,15 +91,10 @@ class PostToPinterestBlock(Block): self, input_data: "PostToPinterestBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to Pinterest with Pinterest-specific options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." @@ -206,7 +200,7 @@ class PostToPinterestBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, pinterest_options=pinterest_options if pinterest_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_reddit.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_reddit.py index 1df721f424..40fbe14cd1 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_reddit.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_reddit.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToRedditBlock(Block): """Block for posting to Reddit.""" @@ -35,12 +39,12 @@ class PostToRedditBlock(Block): ) async def run( - self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs + self, + input_data: "PostToRedditBlock.Input", + *, + credentials: APIKeyCredentials, + **kwargs ) -> BlockOutput: - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured." 
@@ -61,7 +65,7 @@ class PostToRedditBlock(Block): random_post=input_data.random_post, random_media_url=input_data.random_media_url, notes=input_data.notes, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_snapchat.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_snapchat.py index 3645f7cc9b..996518dacf 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_snapchat.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_snapchat.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToSnapchatBlock(Block): """Block for posting to Snapchat with Snapchat-specific options.""" @@ -31,6 +35,14 @@ class PostToSnapchatBlock(Block): advanced=False, ) + # Snapchat is video-only; override the base default so the @cost filter + # selects the 5-credit video tier instead of the 2-credit image tier. + is_video: bool = SchemaField( + description="Whether the media is a video (always True for Snapchat)", + default=True, + advanced=True, + ) + # Snapchat-specific options story_type: str = SchemaField( description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)", @@ -62,15 +74,10 @@ class PostToSnapchatBlock(Block): self, input_data: "PostToSnapchatBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to Snapchat with Snapchat-specific options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." 
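The `is_video` override above only makes sense if `AYRSHARE_POST_COSTS` is a tier table filtered on that field. The real definition lives in `._cost` and is not part of this diff, so the sketch below is an assumption about its shape, using the 2-credit image / 5-credit video figures the comment cites and the `cost_filter` mechanism the cost tests later in this PR rely on:

```python
from backend.sdk import BlockCost, BlockCostType

# Hypothetical shape of AYRSHARE_POST_COSTS; the actual table in
# backend/blocks/ayrshare/_cost.py may differ.
AYRSHARE_POST_COSTS = (
    # 2-credit tier chosen when the input marks the media as an image...
    BlockCost(
        cost_type=BlockCostType.RUN,
        cost_amount=2,
        cost_filter={"is_video": False},
    ),
    # ...and the 5-credit tier chosen when is_video is True, which is why the
    # video-only Snapchat and YouTube blocks default the field to True.
    BlockCost(
        cost_type=BlockCostType.RUN,
        cost_amount=5,
        cost_filter={"is_video": True},
    ),
)
```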
@@ -121,7 +128,7 @@ class PostToSnapchatBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, snapchat_options=snapchat_options if snapchat_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_telegram.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_telegram.py index a220cbe9e8..f526c6ea9f 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_telegram.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_telegram.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToTelegramBlock(Block): """Block for posting to Telegram with Telegram-specific options.""" @@ -57,15 +61,10 @@ class PostToTelegramBlock(Block): self, input_data: "PostToTelegramBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to Telegram with Telegram-specific validation.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." @@ -108,7 +107,7 @@ class PostToTelegramBlock(Block): random_post=input_data.random_post, random_media_url=input_data.random_media_url, notes=input_data.notes, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_threads.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_threads.py index 75983b2d13..ebafc28308 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_threads.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_threads.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToThreadsBlock(Block): """Block for posting to Threads with Threads-specific options.""" @@ -50,15 +54,10 @@ class PostToThreadsBlock(Block): self, input_data: "PostToThreadsBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to Threads with Threads-specific validation.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." 
@@ -103,7 +102,7 @@ class PostToThreadsBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, threads_options=threads_options if threads_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_tiktok.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_tiktok.py index 2d68f10ff0..5b731dcc8b 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_tiktok.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_tiktok.py @@ -2,15 +2,18 @@ from enum import Enum from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client class TikTokVisibility(str, Enum): @@ -19,6 +22,7 @@ class TikTokVisibility(str, Enum): FOLLOWERS = "followers" +@cost(*AYRSHARE_POST_COSTS) class PostToTikTokBlock(Block): """Block for posting to TikTok with TikTok-specific options.""" @@ -113,14 +117,13 @@ class PostToTikTokBlock(Block): ) async def run( - self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs + self, + input_data: "PostToTikTokBlock.Input", + *, + credentials: APIKeyCredentials, + **kwargs, ) -> BlockOutput: """Post to TikTok with TikTok-specific validation and options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." 
@@ -235,7 +238,7 @@ class PostToTikTokBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, tiktok_options=tiktok_options if tiktok_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_x.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_x.py index bbecd31ed4..da1fe48c26 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_x.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_x.py @@ -1,16 +1,20 @@ from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client +@cost(*AYRSHARE_POST_COSTS) class PostToXBlock(Block): """Block for posting to X / Twitter with Twitter-specific options.""" @@ -115,15 +119,10 @@ class PostToXBlock(Block): self, input_data: "PostToXBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to X / Twitter with enhanced X-specific options.""" - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." @@ -233,7 +232,7 @@ class PostToXBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, twitter_options=twitter_options if twitter_options else None, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_youtube.py b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_youtube.py index 8a366ba5c5..021f4c1005 100644 --- a/autogpt_platform/backend/backend/blocks/ayrshare/post_to_youtube.py +++ b/autogpt_platform/backend/backend/blocks/ayrshare/post_to_youtube.py @@ -3,15 +3,18 @@ from typing import Any from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform from backend.sdk import ( + APIKeyCredentials, Block, BlockCategory, BlockOutput, BlockSchemaOutput, BlockType, SchemaField, + cost, ) -from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key +from ._cost import AYRSHARE_POST_COSTS +from ._util import BaseAyrshareInput, create_ayrshare_client class YouTubeVisibility(str, Enum): @@ -20,6 +23,7 @@ class YouTubeVisibility(str, Enum): UNLISTED = "unlisted" +@cost(*AYRSHARE_POST_COSTS) class PostToYouTubeBlock(Block): """Block for posting to YouTube with YouTube-specific options.""" @@ -39,6 +43,14 @@ class PostToYouTubeBlock(Block): advanced=False, ) + # YouTube is video-only; override the base default so the @cost filter + # selects the 5-credit video tier instead of the 2-credit image tier. + is_video: bool = SchemaField( + description="Whether the media is a video (always True for YouTube)", + default=True, + advanced=True, + ) + # YouTube-specific required options title: str = SchemaField( description="Video title (max 100 chars, required). 
Cannot contain < or > characters.", @@ -137,16 +149,10 @@ class PostToYouTubeBlock(Block): self, input_data: "PostToYouTubeBlock.Input", *, - user_id: str, + credentials: APIKeyCredentials, **kwargs, ) -> BlockOutput: """Post to YouTube with YouTube-specific validation and options.""" - - profile_key = await get_profile_key(user_id) - if not profile_key: - yield "error", "Please link a social account via Ayrshare" - return - client = create_ayrshare_client() if not client: yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY." @@ -302,7 +308,7 @@ class PostToYouTubeBlock(Block): random_media_url=input_data.random_media_url, notes=input_data.notes, youtube_options=youtube_options, - profile_key=profile_key.get_secret_value(), + profile_key=credentials.api_key.get_secret_value(), ) yield "post_result", response if response.postIds: diff --git a/autogpt_platform/backend/backend/blocks/baas/bots.py b/autogpt_platform/backend/backend/blocks/baas/bots.py index 68af9a675e..5548074453 100644 --- a/autogpt_platform/backend/backend/blocks/baas/bots.py +++ b/autogpt_platform/backend/backend/blocks/baas/bots.py @@ -4,21 +4,34 @@ Meeting BaaS bot (recording) blocks. from typing import Optional +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, BlockCategory, + BlockCost, + BlockCostType, BlockOutput, BlockSchemaInput, BlockSchemaOutput, CredentialsMetaInput, SchemaField, + cost, ) from ._api import MeetingBaasAPI from ._config import baas +# Meeting BaaS recording rate: $0.69 per hour. +_MEETING_BAAS_USD_PER_SECOND = 0.69 / 3600 +# Join bills a flat 30 cr commit (covers median short meeting); +# FetchMeetingData bills the duration-scaled remainder from the +# `duration_seconds` field on the API response. Long meetings no +# longer under-bill. + + +@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=30)) class BaasBotJoinMeetingBlock(Block): """ Deploy a bot immediately or at a scheduled start_time to join and record a meeting. @@ -134,6 +147,7 @@ class BaasBotLeaveMeetingBlock(Block): yield "left", left +@cost(BlockCost(cost_type=BlockCostType.COST_USD, cost_amount=150)) class BaasBotFetchMeetingDataBlock(Block): """ Pull MP4 URL, transcript & metadata for a completed meeting. @@ -176,9 +190,21 @@ class BaasBotFetchMeetingDataBlock(Block): include_transcripts=input_data.include_transcripts, ) + bot_meta = data.get("bot_data", {}).get("bot", {}) or {} + # Bill recording duration via COST_USD so multi-hour meetings + # scale past the Join block's flat 30 cr deposit. 
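# Worked example (illustrative, assuming the COST_USD resolver rounds up the
# way the Claude Code cost tests later in this diff expect): a 2-hour
# recording reports duration_seconds = 7200, so provider_cost =
# 7200 * (0.69 / 3600) = $1.38; at this block's 150 cr/USD registration that
# resolves to ceil(1.38 * 150) = 207 credits, billed when the recording data
# is fetched rather than at join time.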
+ duration_seconds = float(bot_meta.get("duration_seconds") or 0) + if duration_seconds > 0: + self.merge_stats( + NodeExecutionStats( + provider_cost=duration_seconds * _MEETING_BAAS_USD_PER_SECOND, + provider_cost_type="cost_usd", + ) + ) + yield "mp4_url", data.get("mp4", "") yield "transcript", data.get("bot_data", {}).get("transcripts", []) - yield "metadata", data.get("bot_data", {}).get("bot", {}) + yield "metadata", bot_meta class BaasBotDeleteRecordingBlock(Block): diff --git a/autogpt_platform/backend/backend/blocks/baas/bots_cost_test.py b/autogpt_platform/backend/backend/blocks/baas/bots_cost_test.py new file mode 100644 index 0000000000..aea4f65620 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/baas/bots_cost_test.py @@ -0,0 +1,86 @@ +"""Unit tests for Meeting BaaS duration-based cost emission.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from pydantic import SecretStr + +from backend.blocks.baas.bots import ( + _MEETING_BAAS_USD_PER_SECOND, + BaasBotFetchMeetingDataBlock, +) +from backend.data.model import APIKeyCredentials, NodeExecutionStats + +TEST_CREDENTIALS = APIKeyCredentials( + id="01234567-89ab-cdef-0123-456789abcdef", + provider="baas", + title="Mock BaaS API Key", + api_key=SecretStr("mock-baas-api-key"), + expires_at=None, +) + + +def test_usd_per_second_derives_from_published_rate(): + """$0.69/hour published rate → ~$0.000192/second.""" + assert _MEETING_BAAS_USD_PER_SECOND == pytest.approx(0.69 / 3600) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "duration_seconds, expected_usd", + [ + (3600, 0.69), # 1 hour + (1800, 0.345), # 30 min + (0, None), # no recording → no emission + (None, None), # missing duration field → no emission + ], +) +async def test_fetch_meeting_data_emits_duration_cost_usd( + duration_seconds, expected_usd +): + """FetchMeetingData extracts duration_seconds from bot metadata and + emits provider_cost / cost_usd scaled by the published $0.69/hr rate. + Emission is skipped when duration is 0 or missing. + """ + block = BaasBotFetchMeetingDataBlock() + + bot_meta = {"id": "bot-xyz"} + if duration_seconds is not None: + bot_meta["duration_seconds"] = duration_seconds + + mock_api = AsyncMock() + mock_api.get_meeting_data.return_value = { + "mp4": "https://example/recording.mp4", + "bot_data": {"bot": bot_meta, "transcripts": []}, + } + + captured: list[NodeExecutionStats] = [] + with ( + patch("backend.blocks.baas.bots.MeetingBaasAPI", return_value=mock_api), + patch.object(block, "merge_stats", side_effect=captured.append), + ): + outputs = [] + async for name, val in block.run( + block.input_schema( + credentials={ + "id": TEST_CREDENTIALS.id, + "provider": TEST_CREDENTIALS.provider, + "type": TEST_CREDENTIALS.type, + }, + bot_id="bot-xyz", + include_transcripts=False, + ), + credentials=TEST_CREDENTIALS, + ): + outputs.append((name, val)) + + # Always yields the 3 outputs regardless of duration. 
+ names = [n for n, _ in outputs] + assert "mp4_url" in names and "metadata" in names + + if expected_usd is None: + assert captured == [] + else: + assert len(captured) == 1 + assert captured[0].provider_cost == pytest.approx(expected_usd) + assert captured[0].provider_cost_type == "cost_usd" diff --git a/autogpt_platform/backend/backend/blocks/bannerbear/_config.py b/autogpt_platform/backend/backend/blocks/bannerbear/_config.py index 0303f49ca2..32fe7fff21 100644 --- a/autogpt_platform/backend/backend/blocks/bannerbear/_config.py +++ b/autogpt_platform/backend/backend/blocks/bannerbear/_config.py @@ -3,6 +3,6 @@ from backend.sdk import BlockCostType, ProviderBuilder bannerbear = ( ProviderBuilder("bannerbear") .with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key") - .with_base_cost(1, BlockCostType.RUN) + .with_base_cost(3, BlockCostType.RUN) .build() ) diff --git a/autogpt_platform/backend/backend/blocks/claude_code.py b/autogpt_platform/backend/backend/blocks/claude_code.py index 2e870f02b6..03c8f70312 100644 --- a/autogpt_platform/backend/backend/blocks/claude_code.py +++ b/autogpt_platform/backend/backend/blocks/claude_code.py @@ -17,6 +17,7 @@ from backend.data.model import ( APIKeyCredentials, CredentialsField, CredentialsMetaInput, + NodeExecutionStats, SchemaField, ) from backend.integrations.providers import ProviderName @@ -431,6 +432,7 @@ class ClaudeCodeBlock(Block): # The JSON output contains the result output_data = json.loads(raw_output) response = output_data.get("result", raw_output) + self._record_cli_cost(output_data) # Build conversation history entry turn_entry = f"User: {prompt}\nClaude: {response}" @@ -484,6 +486,23 @@ class ClaudeCodeBlock(Block): escaped = prompt.replace("'", "'\"'\"'") return f"'{escaped}'" + def _record_cli_cost(self, output_data: dict) -> None: + """Feed Claude Code CLI's `total_cost_usd` to the COST_USD resolver. + + The CLI rolls up Anthropic LLM + internal tool-call spend into + ``total_cost_usd`` on its JSON response; piping it through + ``merge_stats`` lets the wallet reflect real spend. + """ + total_cost_usd = output_data.get("total_cost_usd") + if total_cost_usd is None: + return + self.merge_stats( + NodeExecutionStats( + provider_cost=float(total_cost_usd), + provider_cost_type="cost_usd", + ) + ) + async def run( self, input_data: Input, diff --git a/autogpt_platform/backend/backend/blocks/claude_code_cost_test.py b/autogpt_platform/backend/backend/blocks/claude_code_cost_test.py new file mode 100644 index 0000000000..1e51d72f42 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/claude_code_cost_test.py @@ -0,0 +1,106 @@ +"""Unit tests for ClaudeCodeBlock COST_USD billing migration. + +Verifies: +- Block emits provider_cost / cost_usd when Claude Code CLI returns + total_cost_usd. +- block_usage_cost resolves the COST_USD entry to the expected ceil(usd * + cost_amount) credit charge. +- Missing total_cost_usd gracefully produces provider_cost=None (no bill). 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from backend.blocks._base import BlockCostType +from backend.blocks.claude_code import ClaudeCodeBlock +from backend.data.block_cost_config import BLOCK_COSTS +from backend.data.model import NodeExecutionStats +from backend.executor.utils import block_usage_cost + + +def test_claude_code_registered_as_cost_usd_150(): + """Sanity: BLOCK_COSTS holds the COST_USD, 150 cr/$ entry.""" + entries = BLOCK_COSTS[ClaudeCodeBlock] + assert len(entries) == 1 + entry = entries[0] + assert entry.cost_type == BlockCostType.COST_USD + assert entry.cost_amount == 150 + + +@pytest.mark.parametrize( + "total_cost_usd, expected_credits", + [ + (0.50, 75), # $0.50 × 150 = 75 cr + (1.00, 150), # $1.00 × 150 = 150 cr + (0.0134, 3), # ceil(0.0134 × 150) = ceil(2.01) = 3 + (2.00, 300), # $2 × 150 = 300 cr + (0.001, 1), # ceil(0.001 × 150) = ceil(0.15) = 1 — no 0-cr leak on + # sub-cent runs + ], +) +def test_cost_usd_resolver_applies_150_multiplier(total_cost_usd, expected_credits): + """block_usage_cost with cost_usd stats returns ceil(usd * 150).""" + block = ClaudeCodeBlock() + # cost_filter requires matching e2b_credentials; supply the ones the + # registration uses so _is_cost_filter_match accepts the input. + entry = BLOCK_COSTS[ClaudeCodeBlock][0] + input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]} + stats = NodeExecutionStats( + provider_cost=total_cost_usd, + provider_cost_type="cost_usd", + ) + cost, matching_filter = block_usage_cost( + block=block, input_data=input_data, stats=stats + ) + assert cost == expected_credits + assert matching_filter == entry.cost_filter + + +def test_cost_usd_resolver_returns_zero_when_stats_missing_cost(): + """Pre-flight (no stats) or unbilled run (provider_cost None) → 0.""" + block = ClaudeCodeBlock() + entry = BLOCK_COSTS[ClaudeCodeBlock][0] + input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]} + # No stats at all → pre-flight path, returns 0. + pre_cost, _ = block_usage_cost(block=block, input_data=input_data) + assert pre_cost == 0 + # Stats present but no provider_cost → resolver can't bill. + stats = NodeExecutionStats() + post_cost, _ = block_usage_cost(block=block, input_data=input_data, stats=stats) + assert post_cost == 0 + + +def test_record_cli_cost_emits_provider_cost_when_total_cost_present(): + """``_record_cli_cost`` (the helper called from ``execute_claude_code``) + must emit a single ``merge_stats`` with provider_cost + cost_usd tag + when the CLI JSON payload carries ``total_cost_usd``. + """ + block = ClaudeCodeBlock() + captured: list[NodeExecutionStats] = [] + with patch.object(block, "merge_stats", side_effect=captured.append): + block._record_cli_cost( + { + "result": "hello from claude", + "total_cost_usd": 0.0421, + "usage": {"input_tokens": 1234, "output_tokens": 56}, + } + ) + + assert len(captured) == 1 + stats = captured[0] + assert stats.provider_cost == pytest.approx(0.0421) + assert stats.provider_cost_type == "cost_usd" + + +def test_record_cli_cost_skips_merge_when_total_cost_absent(): + """If the CLI payload lacks ``total_cost_usd`` (legacy / non-JSON + output), ``_record_cli_cost`` must not call ``merge_stats`` — otherwise + we'd pollute telemetry with a ``cost_usd`` emission that has no real + cost attached. 
+ """ + block = ClaudeCodeBlock() + mock = MagicMock() + with patch.object(block, "merge_stats", mock): + block._record_cli_cost({"result": "hello"}) + mock.assert_not_called() diff --git a/autogpt_platform/backend/backend/blocks/codex.py b/autogpt_platform/backend/backend/blocks/codex.py index 07dffec39f..0ff3eb4bc0 100644 --- a/autogpt_platform/backend/backend/blocks/codex.py +++ b/autogpt_platform/backend/backend/blocks/codex.py @@ -151,6 +151,17 @@ class CodeGenerationBlock(Block): ) self.execution_stats = NodeExecutionStats() + # GPT-5.1-Codex published pricing: $1.25 / 1M input, $10 / 1M output. + _INPUT_USD_PER_1M = 1.25 + _OUTPUT_USD_PER_1M = 10.0 + + @staticmethod + def _compute_token_usd(input_tokens: int, output_tokens: int) -> float: + return ( + input_tokens * CodeGenerationBlock._INPUT_USD_PER_1M + + output_tokens * CodeGenerationBlock._OUTPUT_USD_PER_1M + ) / 1_000_000 + async def call_codex( self, *, @@ -189,13 +200,15 @@ class CodeGenerationBlock(Block): response_id = response.id or "" # Update usage stats - self.execution_stats.input_token_count = ( - response.usage.input_tokens if response.usage else 0 - ) - self.execution_stats.output_token_count = ( - response.usage.output_tokens if response.usage else 0 - ) + input_tokens = response.usage.input_tokens if response.usage else 0 + output_tokens = response.usage.output_tokens if response.usage else 0 + self.execution_stats.input_token_count = input_tokens + self.execution_stats.output_token_count = output_tokens self.execution_stats.llm_call_count += 1 + self.execution_stats.provider_cost = self._compute_token_usd( + input_tokens, output_tokens + ) + self.execution_stats.provider_cost_type = "cost_usd" return CodexCallResult( response=text_output, diff --git a/autogpt_platform/backend/backend/blocks/cost_leak_fixes_test.py b/autogpt_platform/backend/backend/blocks/cost_leak_fixes_test.py new file mode 100644 index 0000000000..5f647466de --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/cost_leak_fixes_test.py @@ -0,0 +1,226 @@ +"""Coverage tests for the cost-leak fixes in this PR. + +Each block's ``run()`` / helper emits provider_cost + cost_usd (or items) +via merge_stats so the post-flight resolver bills real provider spend. +Tests here drive that emission path directly so a regression on any one +block surfaces immediately. +""" + +from unittest.mock import patch + +import pytest +from pydantic import SecretStr + +from backend.blocks._base import BlockCostType +from backend.blocks.ai_condition import AIConditionBlock +from backend.data.block_cost_config import BLOCK_COSTS, LLM_COST +from backend.data.model import APIKeyCredentials, NodeExecutionStats + +# -------- AIConditionBlock registration -------- + + +def test_ai_condition_registered_under_llm_cost(): + """AIConditionBlock was running wallet-free before this PR; verify it + now resolves through the same per-model LLM_COST table as every other + LLM block. 
+ """ + assert BLOCK_COSTS[AIConditionBlock] is LLM_COST + + +# -------- Pinecone insert ITEMS emission -------- + + +@pytest.mark.asyncio +async def test_pinecone_insert_emits_items_provider_cost(): + from backend.blocks.pinecone import PineconeInsertBlock + + block = PineconeInsertBlock() + captured: list[NodeExecutionStats] = [] + + class _FakeIndex: + def upsert(self, **_): + return None + + class _FakePinecone: + def __init__(self, *_, **__): + pass + + def Index(self, _name): + return _FakeIndex() + + with ( + patch("backend.blocks.pinecone.Pinecone", _FakePinecone), + patch.object(block, "merge_stats", side_effect=captured.append), + ): + input_data = block.input_schema( + credentials={ + "id": "00000000-0000-0000-0000-000000000000", + "provider": "pinecone", + "type": "api_key", + }, + index="my-index", + chunks=["alpha", "beta", "gamma"], + embeddings=[[0.1] * 4, [0.2] * 4, [0.3] * 4], + namespace="", + metadata={}, + ) + + creds = APIKeyCredentials( + id="00000000-0000-0000-0000-000000000000", + provider="pinecone", + title="mock", + api_key=SecretStr("mock-key"), + expires_at=None, + ) + outputs = [(n, v) async for n, v in block.run(input_data, credentials=creds)] + + assert any(name == "upsert_response" for name, _ in outputs) + assert len(captured) == 1 + stats = captured[0] + assert stats.provider_cost == pytest.approx(3.0) + assert stats.provider_cost_type == "items" + + +# -------- Narration model-aware per-char rate -------- + + +@pytest.mark.parametrize( + "model_id, expected_rate_per_char", + [ + ("eleven_flash_v2_5", 0.000167 * 0.5), + ("eleven_turbo_v2_5", 0.000167 * 0.5), + ("eleven_multilingual_v2", 0.000167 * 1.0), + ("eleven_turbo_v2", 0.000167 * 1.0), + ], +) +def test_narration_per_char_rate_scales_with_model(model_id, expected_rate_per_char): + """Drive VideoNarrationBlock._record_script_cost directly so a regression + that drops the model-aware branching (e.g. hardcoding 1.0 cr/char for + all models) makes this test fail. + """ + from backend.blocks.video.narration import VideoNarrationBlock + + block = VideoNarrationBlock() + captured: list[NodeExecutionStats] = [] + with patch.object(block, "merge_stats", side_effect=captured.append): + block._record_script_cost("x" * 5000, model_id) + + assert len(captured) == 1 + stats = captured[0] + assert stats.provider_cost == pytest.approx(5000 * expected_rate_per_char) + assert stats.provider_cost_type == "cost_usd" + + +# -------- Perplexity None-guard on x-total-cost -------- + + +@pytest.mark.parametrize( + "openrouter_cost, expect_type", + [ + (0.0421, "cost_usd"), # concrete positive USD → tagged + (None, None), # header missing → no tag (keeps gap observable) + (0.0, None), # zero → no tag (wouldn't bill anything anyway) + ], +) +def test_perplexity_record_openrouter_cost_tags_only_on_concrete_value( + openrouter_cost, expect_type +): + """Drive PerplexityBlock._record_openrouter_cost directly to verify the + None/0 guard. A regression that tags cost_usd unconditionally would + silently floor the user's bill to 0 via the resolver — this test + would catch it. 
+ """ + from backend.blocks.perplexity import PerplexityBlock + + block = PerplexityBlock() + with patch( + "backend.blocks.perplexity.extract_openrouter_cost", + return_value=openrouter_cost, + ): + block._record_openrouter_cost(response=object()) + + assert block.execution_stats.provider_cost == openrouter_cost + assert block.execution_stats.provider_cost_type == expect_type + + +# -------- Codex COST_USD registration -------- + + +def test_codex_registered_as_cost_usd_150(): + from backend.blocks.codex import CodeGenerationBlock + + entries = BLOCK_COSTS[CodeGenerationBlock] + assert len(entries) == 1 + entry = entries[0] + assert entry.cost_type == BlockCostType.COST_USD + assert entry.cost_amount == 150 + + +@pytest.mark.parametrize( + "input_tokens, output_tokens, expected_usd", + [ + # GPT-5.1-Codex: $1.25 / 1M input, $10 / 1M output. + (1_000_000, 0, 1.25), + (0, 1_000_000, 10.0), + (100_000, 10_000, 0.225), # 0.125 + 0.100 + (0, 0, 0.0), + ], +) +def test_codex_computes_provider_cost_usd_from_token_counts( + input_tokens, output_tokens, expected_usd +): + """Drive CodeGenerationBlock._compute_token_usd directly. A regression + to the wrong rate constants (e.g. swapping the $1.25 input rate for + GPT-4o's $2.50) would fail this test. + """ + from backend.blocks.codex import CodeGenerationBlock + + assert CodeGenerationBlock._compute_token_usd( + input_tokens, output_tokens + ) == pytest.approx(expected_usd) + + +# -------- ClaudeCode COST_USD registration sanity (already tested in claude_code_cost_test.py) -------- + + +# -------- Perplexity COST_USD registration for all 3 tiers -------- + + +def test_perplexity_sonar_all_tiers_registered_as_cost_usd_150(): + from backend.blocks.perplexity import PerplexityBlock + + entries = BLOCK_COSTS[PerplexityBlock] + # 3 tiers (SONAR, SONAR_PRO, SONAR_DEEP_RESEARCH) all COST_USD 150. + assert len(entries) == 3 + for entry in entries: + assert entry.cost_type == BlockCostType.COST_USD + assert entry.cost_amount == 150 + + +# -------- Narration COST_USD registration -------- + + +def test_narration_registered_as_cost_usd_150(): + from backend.blocks.video.narration import VideoNarrationBlock + + entries = BLOCK_COSTS[VideoNarrationBlock] + assert len(entries) == 1 + assert entries[0].cost_type == BlockCostType.COST_USD + assert entries[0].cost_amount == 150 + + +# -------- Pinecone registrations -------- + + +def test_pinecone_registrations(): + from backend.blocks.pinecone import ( + PineconeInitBlock, + PineconeInsertBlock, + PineconeQueryBlock, + ) + + assert BLOCK_COSTS[PineconeInitBlock][0].cost_type == BlockCostType.RUN + assert BLOCK_COSTS[PineconeQueryBlock][0].cost_type == BlockCostType.RUN + # Insert scales with item count. + assert BLOCK_COSTS[PineconeInsertBlock][0].cost_type == BlockCostType.ITEMS + assert BLOCK_COSTS[PineconeInsertBlock][0].cost_amount == 1 diff --git a/autogpt_platform/backend/backend/blocks/dataforseo/_api.py b/autogpt_platform/backend/backend/blocks/dataforseo/_api.py index 3b3190e66d..b4a30dda0d 100644 --- a/autogpt_platform/backend/backend/blocks/dataforseo/_api.py +++ b/autogpt_platform/backend/backend/blocks/dataforseo/_api.py @@ -19,6 +19,10 @@ class DataForSeoClient: trusted_origins=["https://api.dataforseo.com"], raise_for_status=False, ) + # USD cost reported by DataForSEO on the most recent successful call. + # Populated by keyword_suggestions / related_keywords so the caller + # can surface it via NodeExecutionStats.provider_cost for billing. 
+ self.last_cost_usd: float = 0.0 def _get_headers(self) -> Dict[str, str]: """Generate the authorization header using Basic Auth.""" @@ -97,6 +101,9 @@ class DataForSeoClient: if data.get("tasks") and len(data["tasks"]) > 0: task = data["tasks"][0] if task.get("status_code") == 20000: # Success code + # DataForSEO reports per-task USD cost; stash it so callers + # can populate NodeExecutionStats.provider_cost. + self.last_cost_usd = float(task.get("cost") or 0.0) return task.get("result", []) else: error_msg = task.get("status_message", "Task failed") @@ -174,6 +181,9 @@ class DataForSeoClient: if data.get("tasks") and len(data["tasks"]) > 0: task = data["tasks"][0] if task.get("status_code") == 20000: # Success code + # DataForSEO reports per-task USD cost; stash it so callers + # can populate NodeExecutionStats.provider_cost. + self.last_cost_usd = float(task.get("cost") or 0.0) return task.get("result", []) else: error_msg = task.get("status_message", "Task failed") diff --git a/autogpt_platform/backend/backend/blocks/dataforseo/_config.py b/autogpt_platform/backend/backend/blocks/dataforseo/_config.py index 10b2b91130..ec979de893 100644 --- a/autogpt_platform/backend/backend/blocks/dataforseo/_config.py +++ b/autogpt_platform/backend/backend/blocks/dataforseo/_config.py @@ -12,6 +12,11 @@ dataforseo = ( password_env_var="DATAFORSEO_PASSWORD", title="DataForSEO Credentials", ) - .with_base_cost(1, BlockCostType.RUN) + # DataForSEO reports USD cost per task (e.g. $0.001/keyword returned). + # DataForSeoClient stashes it on last_cost_usd; each block emits it via + # merge_stats so the COST_USD resolver bills against real spend. + # 1000 platform credits per USD → 1 credit per $0.001 (≈ 1 credit/ + # returned keyword on the standard tier). + .with_base_cost(1000, BlockCostType.COST_USD) .build() ) diff --git a/autogpt_platform/backend/backend/blocks/dataforseo/keyword_suggestions.py b/autogpt_platform/backend/backend/blocks/dataforseo/keyword_suggestions.py index a1ecc86386..1c546615f7 100644 --- a/autogpt_platform/backend/backend/blocks/dataforseo/keyword_suggestions.py +++ b/autogpt_platform/backend/backend/blocks/dataforseo/keyword_suggestions.py @@ -4,6 +4,7 @@ DataForSEO Google Keyword Suggestions block. from typing import Any, Dict, List, Optional +from backend.data.model import NodeExecutionStats from backend.sdk import ( Block, BlockCategory, @@ -110,8 +111,10 @@ class DataForSeoKeywordSuggestionsBlock(Block): test_output=[ ( "suggestion", - lambda x: hasattr(x, "keyword") - and x.keyword == "digital marketing strategy", + lambda x: ( + hasattr(x, "keyword") + and x.keyword == "digital marketing strategy" + ), ), ("suggestions", lambda x: isinstance(x, list) and len(x) == 1), ("total_count", 1), @@ -167,6 +170,16 @@ class DataForSeoKeywordSuggestionsBlock(Block): results = await self._fetch_keyword_suggestions(client, input_data) + # DataForSEO reports per-task USD cost on the response. Feed it + # into NodeExecutionStats so the COST_USD resolver bills the + # real provider spend at reconciliation time. 
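# Illustrative conversion using the figures from _config.py above: a task
# returning 250 keywords at the standard $0.001/keyword rate reports roughly
# $0.25 of provider cost, which the 1000 cr/USD base cost resolves to about
# 250 credits, i.e. roughly 1 credit per returned keyword.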
+ self.merge_stats( + NodeExecutionStats( + provider_cost=client.last_cost_usd, + provider_cost_type="cost_usd", + ) + ) + # Process and format the results suggestions = [] if results and len(results) > 0: diff --git a/autogpt_platform/backend/backend/blocks/dataforseo/related_keywords.py b/autogpt_platform/backend/backend/blocks/dataforseo/related_keywords.py index 0757cb6507..711f5ea5ef 100644 --- a/autogpt_platform/backend/backend/blocks/dataforseo/related_keywords.py +++ b/autogpt_platform/backend/backend/blocks/dataforseo/related_keywords.py @@ -4,6 +4,7 @@ DataForSEO Google Related Keywords block. from typing import Any, Dict, List, Optional +from backend.data.model import NodeExecutionStats from backend.sdk import ( Block, BlockCategory, @@ -177,6 +178,16 @@ class DataForSeoRelatedKeywordsBlock(Block): results = await self._fetch_related_keywords(client, input_data) + # DataForSEO reports per-task USD cost on the response. Feed it + # into NodeExecutionStats so the COST_USD resolver bills the + # real provider spend at reconciliation time. + self.merge_stats( + NodeExecutionStats( + provider_cost=client.last_cost_usd, + provider_cost_type="cost_usd", + ) + ) + # Process and format the results related_keywords = [] if results and len(results) > 0: diff --git a/autogpt_platform/backend/backend/blocks/exa/_config.py b/autogpt_platform/backend/backend/blocks/exa/_config.py index bca636b2a8..31a37ba93b 100644 --- a/autogpt_platform/backend/backend/blocks/exa/_config.py +++ b/autogpt_platform/backend/backend/blocks/exa/_config.py @@ -11,6 +11,11 @@ exa = ( ProviderBuilder("exa") .with_api_key("EXA_API_KEY", "Exa API Key") .with_webhook_manager(ExaWebhookManager) - .with_base_cost(1, BlockCostType.RUN) + # Exa returns `cost_dollars.total` on every response and ExaSearchBlock + # (plus ~45 sibling blocks that share this provider config) already + # populates NodeExecutionStats.provider_cost with it. Bill 100 credits + # per USD (~$0.01/credit): cheap searches stay at 1–2 credits, a Deep + # Research run at $0.20 lands at 20 credits, matching provider spend. + .with_base_cost(100, BlockCostType.COST_USD) .build() ) diff --git a/autogpt_platform/backend/backend/blocks/exa/answers.py b/autogpt_platform/backend/backend/blocks/exa/answers.py index 9033d6b5f8..1017346e05 100644 --- a/autogpt_platform/backend/backend/blocks/exa/answers.py +++ b/autogpt_platform/backend/backend/blocks/exa/answers.py @@ -17,6 +17,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost class AnswerCitation(BaseModel): @@ -111,3 +112,7 @@ class ExaAnswerBlock(Block): yield "citations", citations for citation in citations: yield "citation", citation + + # Current SDK AnswerResponse dataclass omits cost_dollars; helper + # no-ops today, but keeps billing wired when exa_py adds the field. 
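# Illustrative, and assuming the COST_USD resolver applies the same ceil()
# rounding the Claude Code tests in this diff assert: once exa_py surfaces
# cost_dollars on AnswerResponse, even a $0.005 answer would bill
# ceil(0.005 * 100) = 1 credit under the provider's 100 cr/USD base cost,
# so sub-cent calls never round down to a free run.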
+ merge_exa_cost(self, response) diff --git a/autogpt_platform/backend/backend/blocks/exa/code_context.py b/autogpt_platform/backend/backend/blocks/exa/code_context.py index 2855c1dc4a..c57844372b 100644 --- a/autogpt_platform/backend/backend/blocks/exa/code_context.py +++ b/autogpt_platform/backend/backend/blocks/exa/code_context.py @@ -9,7 +9,6 @@ from typing import Union from pydantic import BaseModel -from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -23,6 +22,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost class CodeContextResponse(BaseModel): @@ -118,9 +118,5 @@ class ExaCodeContextBlock(Block): yield "search_time", context.search_time yield "output_tokens", context.output_tokens - # Parse cost_dollars (API returns as string, e.g. "0.005") - try: - cost_usd = float(context.cost_dollars) - self.merge_stats(NodeExecutionStats(provider_cost=cost_usd)) - except (ValueError, TypeError): - pass + # API returns costDollars as a bare numeric string like "0.005". + merge_exa_cost(self, data) diff --git a/autogpt_platform/backend/backend/blocks/exa/contents.py b/autogpt_platform/backend/backend/blocks/exa/contents.py index 8b2deaf036..b346cd746d 100644 --- a/autogpt_platform/backend/backend/blocks/exa/contents.py +++ b/autogpt_platform/backend/backend/blocks/exa/contents.py @@ -4,7 +4,6 @@ from typing import Optional from exa_py import AsyncExa from pydantic import BaseModel -from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -24,6 +23,7 @@ from .helpers import ( HighlightSettings, LivecrawlTypes, SummarySettings, + merge_exa_cost, ) @@ -224,6 +224,4 @@ class ExaContentsBlock(Block): if response.cost_dollars: yield "cost_dollars", response.cost_dollars - self.merge_stats( - NodeExecutionStats(provider_cost=response.cost_dollars.total) - ) + merge_exa_cost(self, response) diff --git a/autogpt_platform/backend/backend/blocks/exa/cost_tracking_test.py b/autogpt_platform/backend/backend/blocks/exa/cost_tracking_test.py index 1ee395e539..96d161c6b8 100644 --- a/autogpt_platform/backend/backend/blocks/exa/cost_tracking_test.py +++ b/autogpt_platform/backend/backend/blocks/exa/cost_tracking_test.py @@ -143,7 +143,9 @@ class TestExaContentsCostTracking: mock_exa_cls.return_value = mock_exa async for _ in block.run( - block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + block.Input( + urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT + ), # type: ignore[arg-type] credentials=TEST_CREDENTIALS, ): pass @@ -172,7 +174,9 @@ class TestExaContentsCostTracking: mock_exa_cls.return_value = mock_exa async for _ in block.run( - block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + block.Input( + urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT + ), # type: ignore[arg-type] credentials=TEST_CREDENTIALS, ): pass @@ -201,7 +205,9 @@ class TestExaContentsCostTracking: mock_exa_cls.return_value = mock_exa async for _ in block.run( - block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + block.Input( + urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT + ), # type: ignore[arg-type] credentials=TEST_CREDENTIALS, ): pass @@ -297,7 +303,9 @@ class TestExaSimilarCostTracking: mock_exa_cls.return_value = mock_exa async for _ in block.run( - 
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + block.Input( + url="https://example.com", credentials=TEST_CREDENTIALS_INPUT + ), # type: ignore[arg-type] credentials=TEST_CREDENTIALS, ): pass @@ -326,7 +334,9 @@ class TestExaSimilarCostTracking: mock_exa_cls.return_value = mock_exa async for _ in block.run( - block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type] + block.Input( + url="https://example.com", credentials=TEST_CREDENTIALS_INPUT + ), # type: ignore[arg-type] credentials=TEST_CREDENTIALS, ): pass diff --git a/autogpt_platform/backend/backend/blocks/exa/helpers.py b/autogpt_platform/backend/backend/blocks/exa/helpers.py index f31f01c78a..a6049f0879 100644 --- a/autogpt_platform/backend/backend/blocks/exa/helpers.py +++ b/autogpt_platform/backend/backend/blocks/exa/helpers.py @@ -1,7 +1,8 @@ from enum import Enum from typing import Any, Dict, Literal, Optional, Union -from backend.sdk import BaseModel, MediaFileType, SchemaField +from backend.data.model import NodeExecutionStats +from backend.sdk import BaseModel, Block, MediaFileType, SchemaField class LivecrawlTypes(str, Enum): @@ -319,7 +320,7 @@ class CostDollars(BaseModel): # Helper functions for payload processing def process_text_field( - text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None] + text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None], ) -> Optional[Union[bool, Dict[str, Any]]]: """Process text field for API payload.""" if text is None: @@ -400,7 +401,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str, def process_context_field( - context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None] + context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None], ) -> Optional[Union[bool, Dict[str, int]]]: """Process context field for API payload.""" if context is None: @@ -448,3 +449,65 @@ def add_optional_fields( payload[api_field] = value.value else: payload[api_field] = value + + +def extract_exa_cost_usd(response: Any) -> Optional[float]: + """Return ``cost_dollars.total`` (USD) from an Exa SDK response, or None. + + Handles dataclass/pydantic responses (``response.cost_dollars.total``), + dicts with camelCase keys (``response["costDollars"]["total"]``), dicts + with snake_case keys, and bare numeric strings. Returns None whenever the + shape is missing cost info — the caller then skips merge_stats. 
+ """ + if response is None: + return None + + # Dataclass / pydantic: response.cost_dollars + cost_obj = getattr(response, "cost_dollars", None) + + # Dict payloads: try both camelCase and snake_case + if cost_obj is None and isinstance(response, dict): + cost_obj = response.get("costDollars") or response.get("cost_dollars") + + if cost_obj is None: + return None + + # Already a scalar (code_context endpoint returns a string) + if isinstance(cost_obj, (int, float)): + return max(0.0, float(cost_obj)) + if isinstance(cost_obj, str): + try: + return max(0.0, float(cost_obj)) + except ValueError: + return None + + # Nested object/dict: grab the `total` field + total = getattr(cost_obj, "total", None) + if total is None and isinstance(cost_obj, dict): + total = cost_obj.get("total") + + if total is None: + return None + + try: + return max(0.0, float(total)) + except (TypeError, ValueError): + return None + + +def merge_exa_cost(block: Block, response: Any) -> None: + """Pull ``cost_dollars.total`` off an Exa response and merge it into stats. + + No-op when the response shape has no cost info (e.g. webset CRUD where + the SDK does not expose per-call pricing) — emission happens only when + Exa actually reports a USD amount. + """ + cost_usd = extract_exa_cost_usd(response) + if cost_usd is None: + return + block.merge_stats( + NodeExecutionStats( + provider_cost=cost_usd, + provider_cost_type="cost_usd", + ) + ) diff --git a/autogpt_platform/backend/backend/blocks/exa/helpers_cost_test.py b/autogpt_platform/backend/backend/blocks/exa/helpers_cost_test.py new file mode 100644 index 0000000000..9c321a7142 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/exa/helpers_cost_test.py @@ -0,0 +1,65 @@ +"""Unit tests for exa/helpers cost-extraction + merge helpers.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from backend.blocks.exa.helpers import extract_exa_cost_usd, merge_exa_cost +from backend.data.model import NodeExecutionStats + + +@pytest.mark.parametrize( + "response, expected", + [ + # Dataclass / SimpleNamespace with cost_dollars.total + (SimpleNamespace(cost_dollars=SimpleNamespace(total=0.05)), 0.05), + # Dict camelCase + ({"costDollars": {"total": 0.10}}, 0.10), + # Dict snake_case + ({"cost_dollars": {"total": 0.07}}, 0.07), + # code_context endpoint shape: plain numeric string + (SimpleNamespace(cost_dollars="0.005"), 0.005), + # Scalar float on cost_dollars directly + (SimpleNamespace(cost_dollars=0.02), 0.02), + # Scalar int on cost_dollars + (SimpleNamespace(cost_dollars=3), 3.0), + # Missing cost info — returns None + ({}, None), + (SimpleNamespace(other="foo"), None), + (None, None), + # Nested total=None + (SimpleNamespace(cost_dollars=SimpleNamespace(total=None)), None), + # Invalid numeric string + (SimpleNamespace(cost_dollars="not-a-number"), None), + # Negative values clamp to 0 + (SimpleNamespace(cost_dollars=SimpleNamespace(total=-1.0)), 0.0), + ], +) +def test_extract_exa_cost_usd_handles_all_shapes(response, expected): + assert extract_exa_cost_usd(response) == expected + + +def test_merge_exa_cost_emits_stats_when_cost_present(): + block = MagicMock() + response = SimpleNamespace(cost_dollars=SimpleNamespace(total=0.0421)) + merge_exa_cost(block, response) + + block.merge_stats.assert_called_once() + stats: NodeExecutionStats = block.merge_stats.call_args.args[0] + assert stats.provider_cost == pytest.approx(0.0421) + assert stats.provider_cost_type == "cost_usd" + + +def 
test_merge_exa_cost_noops_when_no_cost(): + """Webset CRUD endpoints don't surface cost_dollars today — the helper + must silently skip instead of emitting a 0-cost telemetry record.""" + block = MagicMock() + merge_exa_cost(block, SimpleNamespace(other_field="nothing")) + block.merge_stats.assert_not_called() + + +def test_merge_exa_cost_noops_when_response_is_none(): + block = MagicMock() + merge_exa_cost(block, None) + block.merge_stats.assert_not_called() diff --git a/autogpt_platform/backend/backend/blocks/exa/research.py b/autogpt_platform/backend/backend/blocks/exa/research.py index 575a88cc01..91693bbe0d 100644 --- a/autogpt_platform/backend/backend/blocks/exa/research.py +++ b/autogpt_platform/backend/backend/blocks/exa/research.py @@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel -from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -26,6 +25,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost class ResearchModel(str, Enum): @@ -233,11 +233,7 @@ class ExaCreateResearchBlock(Block): if research.cost_dollars: yield "cost_total", research.cost_dollars.total - self.merge_stats( - NodeExecutionStats( - provider_cost=research.cost_dollars.total - ) - ) + merge_exa_cost(self, research) return await asyncio.sleep(check_interval) @@ -352,9 +348,7 @@ class ExaGetResearchBlock(Block): yield "cost_searches", research.cost_dollars.num_searches yield "cost_pages", research.cost_dollars.num_pages yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens - self.merge_stats( - NodeExecutionStats(provider_cost=research.cost_dollars.total) - ) + merge_exa_cost(self, research) yield "error_message", research.error @@ -441,9 +435,7 @@ class ExaWaitForResearchBlock(Block): if research.cost_dollars: yield "cost_total", research.cost_dollars.total - self.merge_stats( - NodeExecutionStats(provider_cost=research.cost_dollars.total) - ) + merge_exa_cost(self, research) return diff --git a/autogpt_platform/backend/backend/blocks/exa/search.py b/autogpt_platform/backend/backend/blocks/exa/search.py index 5d9e99698f..4b17048707 100644 --- a/autogpt_platform/backend/backend/blocks/exa/search.py +++ b/autogpt_platform/backend/backend/blocks/exa/search.py @@ -4,7 +4,6 @@ from typing import Optional from exa_py import AsyncExa -from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -21,6 +20,7 @@ from .helpers import ( ContentSettings, CostDollars, ExaSearchResults, + merge_exa_cost, process_contents_settings, ) @@ -207,6 +207,4 @@ class ExaSearchBlock(Block): if response.cost_dollars: yield "cost_dollars", response.cost_dollars - self.merge_stats( - NodeExecutionStats(provider_cost=response.cost_dollars.total) - ) + merge_exa_cost(self, response) diff --git a/autogpt_platform/backend/backend/blocks/exa/similar.py b/autogpt_platform/backend/backend/blocks/exa/similar.py index 004dfec4d6..9a162480b4 100644 --- a/autogpt_platform/backend/backend/blocks/exa/similar.py +++ b/autogpt_platform/backend/backend/blocks/exa/similar.py @@ -3,7 +3,6 @@ from typing import Optional from exa_py import AsyncExa -from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -20,6 +19,7 @@ from .helpers import ( ContentSettings, CostDollars, ExaSearchResults, + merge_exa_cost, process_contents_settings, ) @@ -168,6 +168,4 @@ class ExaFindSimilarBlock(Block): if response.cost_dollars: yield 
"cost_dollars", response.cost_dollars - self.merge_stats( - NodeExecutionStats(provider_cost=response.cost_dollars.total) - ) + merge_exa_cost(self, response) diff --git a/autogpt_platform/backend/backend/blocks/exa/websets.py b/autogpt_platform/backend/backend/blocks/exa/websets.py index ce623ad410..99bbc64c57 100644 --- a/autogpt_platform/backend/backend/blocks/exa/websets.py +++ b/autogpt_platform/backend/backend/blocks/exa/websets.py @@ -39,6 +39,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost class SearchEntityType(str, Enum): @@ -394,6 +395,7 @@ class ExaCreateWebsetBlock(Block): metadata=input_data.metadata, ) ) + merge_exa_cost(self, webset) webset_result = Webset.model_validate(webset.model_dump(by_alias=True)) @@ -404,6 +406,7 @@ class ExaCreateWebsetBlock(Block): timeout=input_data.polling_timeout, poll_interval=5, ) + merge_exa_cost(self, final_webset) completion_time = time.time() - start_time item_count = 0 @@ -479,6 +482,7 @@ class ExaCreateOrFindWebsetBlock(Block): try: webset = await aexa.websets.get(id=input_data.external_id) + merge_exa_cost(self, webset) webset_result = Webset.model_validate(webset.model_dump(by_alias=True)) yield "webset", webset_result @@ -501,6 +505,7 @@ class ExaCreateOrFindWebsetBlock(Block): metadata=input_data.metadata, ) ) + merge_exa_cost(self, webset) webset_result = Webset.model_validate(webset.model_dump(by_alias=True)) @@ -555,6 +560,7 @@ class ExaUpdateWebsetBlock(Block): payload["metadata"] = input_data.metadata sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload) + merge_exa_cost(self, sdk_webset) status_str = ( sdk_webset.status.value @@ -566,8 +572,9 @@ class ExaUpdateWebsetBlock(Block): yield "status", status_str yield "external_id", sdk_webset.external_id yield "metadata", sdk_webset.metadata or {} - yield "updated_at", ( - sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else "" + yield ( + "updated_at", + (sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""), ) @@ -621,6 +628,7 @@ class ExaListWebsetsBlock(Block): cursor=input_data.cursor, limit=input_data.limit, ) + merge_exa_cost(self, response) websets_data = [ w.model_dump(by_alias=True, exclude_none=True) for w in response.data @@ -679,6 +687,7 @@ class ExaGetWebsetBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) sdk_webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, sdk_webset) status_str = ( sdk_webset.status.value @@ -706,11 +715,13 @@ class ExaGetWebsetBlock(Block): yield "enrichments", enrichments_data yield "monitors", monitors_data yield "metadata", sdk_webset.metadata or {} - yield "created_at", ( - sdk_webset.created_at.isoformat() if sdk_webset.created_at else "" + yield ( + "created_at", + (sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""), ) - yield "updated_at", ( - sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else "" + yield ( + "updated_at", + (sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""), ) @@ -749,6 +760,7 @@ class ExaDeleteWebsetBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) deleted_webset = await aexa.websets.delete(id=input_data.webset_id) + merge_exa_cost(self, deleted_webset) status_str = ( deleted_webset.status.value @@ -799,6 +811,7 @@ class ExaCancelWebsetBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) canceled_webset = await aexa.websets.cancel(id=input_data.webset_id) + 
merge_exa_cost(self, canceled_webset) status_str = ( canceled_webset.status.value @@ -969,6 +982,7 @@ class ExaPreviewWebsetBlock(Block): payload["entity"] = entity sdk_preview = await aexa.websets.preview(params=payload) + merge_exa_cost(self, sdk_preview) preview = PreviewWebsetModel.from_sdk(sdk_preview) @@ -1052,6 +1066,7 @@ class ExaWebsetStatusBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) status = ( webset.status.value @@ -1186,6 +1201,7 @@ class ExaWebsetSummaryBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) # Extract basic info webset_id = webset.id @@ -1214,6 +1230,7 @@ class ExaWebsetSummaryBlock(Block): items_response = await aexa.websets.items.list( webset_id=input_data.webset_id, limit=input_data.sample_size ) + merge_exa_cost(self, items_response) sample_items_data = [ item.model_dump(by_alias=True, exclude_none=True) for item in items_response.data @@ -1363,6 +1380,7 @@ class ExaWebsetReadyCheckBlock(Block): # Get webset details webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) status = ( webset.status.value diff --git a/autogpt_platform/backend/backend/blocks/exa/websets_enrichment.py b/autogpt_platform/backend/backend/blocks/exa/websets_enrichment.py index f136b996b9..f442764bfc 100644 --- a/autogpt_platform/backend/backend/blocks/exa/websets_enrichment.py +++ b/autogpt_platform/backend/backend/blocks/exa/websets_enrichment.py @@ -25,6 +25,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost # Mirrored model for stability @@ -205,6 +206,7 @@ class ExaCreateEnrichmentBlock(Block): sdk_enrichment = await aexa.websets.enrichments.create( webset_id=input_data.webset_id, params=payload ) + merge_exa_cost(self, sdk_enrichment) enrichment_id = sdk_enrichment.id status = ( @@ -226,6 +228,7 @@ class ExaCreateEnrichmentBlock(Block): current_enrich = await aexa.websets.enrichments.get( webset_id=input_data.webset_id, id=enrichment_id ) + merge_exa_cost(self, current_enrich) current_status = ( current_enrich.status.value if hasattr(current_enrich.status, "value") @@ -235,6 +238,7 @@ class ExaCreateEnrichmentBlock(Block): if current_status in ["completed", "failed", "cancelled"]: # Estimate items from webset searches webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) if webset.searches: for search in webset.searches: if search.progress: @@ -332,6 +336,7 @@ class ExaGetEnrichmentBlock(Block): sdk_enrichment = await aexa.websets.enrichments.get( webset_id=input_data.webset_id, id=input_data.enrichment_id ) + merge_exa_cost(self, sdk_enrichment) enrichment = WebsetEnrichmentModel.from_sdk(sdk_enrichment) @@ -425,6 +430,7 @@ class ExaUpdateEnrichmentBlock(Block): try: response = await Requests().patch(url, headers=headers, json=payload) data = response.json() + # PATCH /websets/{id}/enrichments/{id} doesn't return costDollars. 
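# Context for the merge_exa_cost(...) call sites throughout this patch: the
# helper itself (backend/blocks/exa/helpers.py) is not shown in this diff.
# Based on the inline merge_stats() calls it replaces and the no-op tests at
# the top of the patch, a minimal sketch of the assumed behavior follows; the
# parameter names, the attribute access, and the "cost_usd" tag are
# assumptions, not the confirmed implementation.
#
#     from backend.data.model import NodeExecutionStats
#
#     def merge_exa_cost(block, response) -> None:
#         """Fold an Exa response's cost_dollars.total into block stats.
#
#         Silently no-ops when the response is None or carries no
#         cost_dollars, so callers can invoke it unconditionally."""
#         cost = getattr(response, "cost_dollars", None) if response else None
#         if not cost:
#             return
#         block.merge_stats(
#             NodeExecutionStats(
#                 provider_cost=cost.total,
#                 provider_cost_type="cost_usd",  # assumed; mirrors other providers in this PR
#             )
#         )
#
# Centralizing the skip logic in one helper is what lets endpoints that never
# return cost data (such as the enrichment PATCH noted in the comment above)
# simply record nothing instead of a spurious 0-cost telemetry entry.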
yield "enrichment_id", data.get("id", "") yield "status", data.get("status", "") @@ -477,6 +483,7 @@ class ExaDeleteEnrichmentBlock(Block): deleted_enrichment = await aexa.websets.enrichments.delete( webset_id=input_data.webset_id, id=input_data.enrichment_id ) + merge_exa_cost(self, deleted_enrichment) yield "enrichment_id", deleted_enrichment.id yield "success", "true" @@ -528,12 +535,14 @@ class ExaCancelEnrichmentBlock(Block): canceled_enrichment = await aexa.websets.enrichments.cancel( webset_id=input_data.webset_id, id=input_data.enrichment_id ) + merge_exa_cost(self, canceled_enrichment) # Try to estimate how many items were enriched before cancellation items_enriched = 0 items_response = await aexa.websets.items.list( webset_id=input_data.webset_id, limit=100 ) + merge_exa_cost(self, items_response) for sdk_item in items_response.data: # Check if this enrichment is present diff --git a/autogpt_platform/backend/backend/blocks/exa/websets_import_export.py b/autogpt_platform/backend/backend/blocks/exa/websets_import_export.py index e5a6137ed4..a865ff4bf3 100644 --- a/autogpt_platform/backend/backend/blocks/exa/websets_import_export.py +++ b/autogpt_platform/backend/backend/blocks/exa/websets_import_export.py @@ -29,6 +29,7 @@ from backend.sdk import ( from ._config import exa from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT +from .helpers import merge_exa_cost # Mirrored model for stability - don't use SDK types directly in block outputs @@ -297,6 +298,7 @@ class ExaCreateImportBlock(Block): sdk_import = await aexa.websets.imports.create( params=payload, csv_data=input_data.csv_data ) + merge_exa_cost(self, sdk_import) import_obj = ImportModel.from_sdk(sdk_import) @@ -361,6 +363,7 @@ class ExaGetImportBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id) + merge_exa_cost(self, sdk_import) import_obj = ImportModel.from_sdk(sdk_import) @@ -430,6 +433,7 @@ class ExaListImportsBlock(Block): cursor=input_data.cursor, limit=input_data.limit, ) + merge_exa_cost(self, response) # Convert SDK imports to our stable models imports = [ImportModel.from_sdk(i) for i in response.data] @@ -477,6 +481,7 @@ class ExaDeleteImportBlock(Block): deleted_import = await aexa.websets.imports.delete( import_id=input_data.import_id ) + merge_exa_cost(self, deleted_import) yield "import_id", deleted_import.id yield "success", "true" @@ -599,7 +604,7 @@ class ExaExportWebsetBlock(Block): try: all_items = [] - # Use SDK's list_all iterator to fetch items + # list_all paginates internally; cost_dollars is not surfaced per-page item_iterator = aexa.websets.items.list_all( webset_id=input_data.webset_id, limit=input_data.max_items ) diff --git a/autogpt_platform/backend/backend/blocks/exa/websets_items.py b/autogpt_platform/backend/backend/blocks/exa/websets_items.py index cdccb89b8d..cf9b0fc9a3 100644 --- a/autogpt_platform/backend/backend/blocks/exa/websets_items.py +++ b/autogpt_platform/backend/backend/blocks/exa/websets_items.py @@ -30,6 +30,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost # Mirrored model for enrichment results @@ -181,6 +182,7 @@ class ExaGetWebsetItemBlock(Block): sdk_item = await aexa.websets.items.get( webset_id=input_data.webset_id, id=input_data.item_id ) + merge_exa_cost(self, sdk_item) item = WebsetItemModel.from_sdk(sdk_item) @@ -293,6 +295,7 @@ class ExaListWebsetItemsBlock(Block): cursor=input_data.cursor, limit=input_data.limit, 
) + merge_exa_cost(self, response) items = [WebsetItemModel.from_sdk(item) for item in response.data] @@ -343,6 +346,7 @@ class ExaDeleteWebsetItemBlock(Block): deleted_item = await aexa.websets.items.delete( webset_id=input_data.webset_id, id=input_data.item_id ) + merge_exa_cost(self, deleted_item) yield "item_id", deleted_item.id yield "success", "true" @@ -404,6 +408,7 @@ class ExaBulkWebsetItemsBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) all_items: List[WebsetItemModel] = [] + # list_all paginates internally; cost_dollars is not surfaced per-page item_iterator = aexa.websets.items.list_all( webset_id=input_data.webset_id, limit=input_data.max_items ) @@ -476,6 +481,7 @@ class ExaWebsetItemsSummaryBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) entity_type = "unknown" if webset.searches: @@ -498,6 +504,7 @@ class ExaWebsetItemsSummaryBlock(Block): items_response = await aexa.websets.items.list( webset_id=input_data.webset_id, limit=input_data.sample_size ) + merge_exa_cost(self, items_response) # Convert to our stable models sample_items = [ WebsetItemModel.from_sdk(item) for item in items_response.data @@ -574,6 +581,7 @@ class ExaGetNewItemsBlock(Block): cursor=input_data.since_cursor, limit=input_data.max_items, ) + merge_exa_cost(self, response) # Convert SDK items to our stable models new_items = [WebsetItemModel.from_sdk(item) for item in response.data] diff --git a/autogpt_platform/backend/backend/blocks/exa/websets_monitor.py b/autogpt_platform/backend/backend/blocks/exa/websets_monitor.py index 8f9836965e..9e1a13243d 100644 --- a/autogpt_platform/backend/backend/blocks/exa/websets_monitor.py +++ b/autogpt_platform/backend/backend/blocks/exa/websets_monitor.py @@ -25,6 +25,7 @@ from backend.sdk import ( from ._config import exa from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT +from .helpers import merge_exa_cost # Mirrored model for stability - don't use SDK types directly in block outputs @@ -321,6 +322,7 @@ class ExaCreateMonitorBlock(Block): payload["metadata"] = input_data.metadata sdk_monitor = await aexa.websets.monitors.create(params=payload) + merge_exa_cost(self, sdk_monitor) monitor = MonitorModel.from_sdk(sdk_monitor) @@ -385,6 +387,7 @@ class ExaGetMonitorBlock(Block): aexa = AsyncExa(api_key=credentials.api_key.get_secret_value()) sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id) + merge_exa_cost(self, sdk_monitor) monitor = MonitorModel.from_sdk(sdk_monitor) @@ -479,6 +482,7 @@ class ExaUpdateMonitorBlock(Block): sdk_monitor = await aexa.websets.monitors.update( monitor_id=input_data.monitor_id, params=payload ) + merge_exa_cost(self, sdk_monitor) # Convert to our stable model monitor = MonitorModel.from_sdk(sdk_monitor) @@ -525,6 +529,7 @@ class ExaDeleteMonitorBlock(Block): deleted_monitor = await aexa.websets.monitors.delete( monitor_id=input_data.monitor_id ) + merge_exa_cost(self, deleted_monitor) yield "monitor_id", deleted_monitor.id yield "success", "true" @@ -586,6 +591,7 @@ class ExaListMonitorsBlock(Block): limit=input_data.limit, webset_id=input_data.webset_id, ) + merge_exa_cost(self, response) # Convert SDK monitors to our stable models monitors = [MonitorModel.from_sdk(m) for m in response.data] diff --git a/autogpt_platform/backend/backend/blocks/exa/websets_polling.py b/autogpt_platform/backend/backend/blocks/exa/websets_polling.py index 
f4168f1446..07cdcb0cec 100644 --- a/autogpt_platform/backend/backend/blocks/exa/websets_polling.py +++ b/autogpt_platform/backend/backend/blocks/exa/websets_polling.py @@ -25,6 +25,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost # Import WebsetItemModel for use in enrichment samples # This is safe as websets_items doesn't import from websets_polling @@ -126,6 +127,7 @@ class ExaWaitForWebsetBlock(Block): timeout=input_data.timeout, poll_interval=input_data.check_interval, ) + merge_exa_cost(self, final_webset) elapsed = time.time() - start_time @@ -165,6 +167,7 @@ class ExaWaitForWebsetBlock(Block): while time.time() - start_time < input_data.timeout: # Get current webset status webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) current_status = ( webset.status.value if hasattr(webset.status, "value") @@ -210,6 +213,7 @@ class ExaWaitForWebsetBlock(Block): # Timeout reached elapsed = time.time() - start_time webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) final_status = ( webset.status.value if hasattr(webset.status, "value") @@ -348,6 +352,7 @@ class ExaWaitForSearchBlock(Block): search = await aexa.websets.searches.get( webset_id=input_data.webset_id, id=input_data.search_id ) + merge_exa_cost(self, search) # Extract status status = ( @@ -404,6 +409,7 @@ class ExaWaitForSearchBlock(Block): search = await aexa.websets.searches.get( webset_id=input_data.webset_id, id=input_data.search_id ) + merge_exa_cost(self, search) final_status = ( search.status.value if hasattr(search.status, "value") @@ -506,6 +512,7 @@ class ExaWaitForEnrichmentBlock(Block): enrichment = await aexa.websets.enrichments.get( webset_id=input_data.webset_id, id=input_data.enrichment_id ) + merge_exa_cost(self, enrichment) # Extract status status = ( @@ -523,16 +530,20 @@ class ExaWaitForEnrichmentBlock(Block): items_enriched = 0 if input_data.sample_results and status == "completed": - sample_data, items_enriched = ( - await self._get_sample_enrichments( - input_data.webset_id, input_data.enrichment_id, aexa - ) + ( + sample_data, + items_enriched, + ) = await self._get_sample_enrichments( + input_data.webset_id, input_data.enrichment_id, aexa ) yield "enrichment_id", input_data.enrichment_id yield "final_status", status yield "items_enriched", items_enriched - yield "enrichment_title", enrichment.title or enrichment.description or "" + yield ( + "enrichment_title", + enrichment.title or enrichment.description or "", + ) yield "elapsed_time", elapsed if input_data.sample_results: yield "sample_data", sample_data @@ -551,6 +562,7 @@ class ExaWaitForEnrichmentBlock(Block): enrichment = await aexa.websets.enrichments.get( webset_id=input_data.webset_id, id=input_data.enrichment_id ) + merge_exa_cost(self, enrichment) final_status = ( enrichment.status.value if hasattr(enrichment.status, "value") @@ -576,6 +588,7 @@ class ExaWaitForEnrichmentBlock(Block): """Get sample enriched data and count.""" # Get a few items to see enrichment results using SDK response = await aexa.websets.items.list(webset_id=webset_id, limit=5) + merge_exa_cost(self, response) sample_data: list[SampleEnrichmentModel] = [] enriched_count = 0 diff --git a/autogpt_platform/backend/backend/blocks/exa/websets_search.py b/autogpt_platform/backend/backend/blocks/exa/websets_search.py index 77184b6cdf..77ba59d98d 100644 --- a/autogpt_platform/backend/backend/blocks/exa/websets_search.py +++ 
b/autogpt_platform/backend/backend/blocks/exa/websets_search.py @@ -24,6 +24,7 @@ from backend.sdk import ( ) from ._config import exa +from .helpers import merge_exa_cost # Mirrored model for stability @@ -320,6 +321,7 @@ class ExaCreateWebsetSearchBlock(Block): sdk_search = await aexa.websets.searches.create( webset_id=input_data.webset_id, params=payload ) + merge_exa_cost(self, sdk_search) search_id = sdk_search.id status = ( @@ -353,6 +355,7 @@ class ExaCreateWebsetSearchBlock(Block): current_search = await aexa.websets.searches.get( webset_id=input_data.webset_id, id=search_id ) + merge_exa_cost(self, current_search) current_status = ( current_search.status.value if hasattr(current_search.status, "value") @@ -445,6 +448,7 @@ class ExaGetWebsetSearchBlock(Block): sdk_search = await aexa.websets.searches.get( webset_id=input_data.webset_id, id=input_data.search_id ) + merge_exa_cost(self, sdk_search) search = WebsetSearchModel.from_sdk(sdk_search) @@ -526,6 +530,7 @@ class ExaCancelWebsetSearchBlock(Block): canceled_search = await aexa.websets.searches.cancel( webset_id=input_data.webset_id, id=input_data.search_id ) + merge_exa_cost(self, canceled_search) # Extract items found before cancellation items_found = 0 @@ -605,6 +610,7 @@ class ExaFindOrCreateSearchBlock(Block): # Get webset to check existing searches webset = await aexa.websets.get(id=input_data.webset_id) + merge_exa_cost(self, webset) # Look for existing search with same query existing_search = None @@ -639,6 +645,7 @@ class ExaFindOrCreateSearchBlock(Block): sdk_search = await aexa.websets.searches.create( webset_id=input_data.webset_id, params=payload ) + merge_exa_cost(self, sdk_search) search = WebsetSearchModel.from_sdk(sdk_search) diff --git a/autogpt_platform/backend/backend/blocks/firecrawl/_config.py b/autogpt_platform/backend/backend/blocks/firecrawl/_config.py index cc176c4a86..7a2ff95fe4 100644 --- a/autogpt_platform/backend/backend/blocks/firecrawl/_config.py +++ b/autogpt_platform/backend/backend/blocks/firecrawl/_config.py @@ -1,8 +1,14 @@ from backend.sdk import BlockCostType, ProviderBuilder +# Firecrawl bills in its own credits (1 credit ≈ $0.001). Each block's +# run() estimates USD spend from the operation (pages scraped, limit, +# credits_used on ExtractResponse) and merge_stats populates +# NodeExecutionStats.provider_cost before billing reconciliation. 1000 +# platform credits per USD means 1 platform credit per Firecrawl credit +# — roughly matches our existing per-call tier for single-page scrape. firecrawl = ( ProviderBuilder("firecrawl") .with_api_key("FIRECRAWL_API_KEY", "Firecrawl API Key") - .with_base_cost(1, BlockCostType.RUN) + .with_base_cost(1000, BlockCostType.COST_USD) .build() ) diff --git a/autogpt_platform/backend/backend/blocks/firecrawl/crawl.py b/autogpt_platform/backend/backend/blocks/firecrawl/crawl.py index eced461a8a..0c88b85e59 100644 --- a/autogpt_platform/backend/backend/blocks/firecrawl/crawl.py +++ b/autogpt_platform/backend/backend/blocks/firecrawl/crawl.py @@ -4,6 +4,7 @@ from firecrawl import FirecrawlApp from firecrawl.v2.types import ScrapeOptions from backend.blocks.firecrawl._api import ScrapeFormat +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -86,6 +87,14 @@ class FirecrawlCrawlBlock(Block): wait_for=input_data.wait_for, ), ) + # Firecrawl bills 1 credit (~$0.001) per crawled page. crawl_result.data + # is the list of scraped pages actually returned. 
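# (Worked example, using only the figures stated in _config.py above and the
# per-page comment here: a crawl that returns 40 pages books
# provider_cost = 40 * 0.001 = $0.04, which the COST_USD base cost of 1000
# platform credits per USD reconciles to 40 platform credits, i.e. one
# platform credit per Firecrawl credit.)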
+ pages = len(crawl_result.data) if crawl_result.data else 0 + self.merge_stats( + NodeExecutionStats( + provider_cost=pages * 0.001, provider_cost_type="cost_usd" + ) + ) yield "data", crawl_result.data for data in crawl_result.data: diff --git a/autogpt_platform/backend/backend/blocks/firecrawl/extract.py b/autogpt_platform/backend/backend/blocks/firecrawl/extract.py index e5fd5ec9f3..c86feb1b09 100755 --- a/autogpt_platform/backend/backend/blocks/firecrawl/extract.py +++ b/autogpt_platform/backend/backend/blocks/firecrawl/extract.py @@ -2,25 +2,22 @@ from typing import Any from firecrawl import FirecrawlApp +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, BlockCategory, - BlockCost, - BlockCostType, BlockOutput, BlockSchemaInput, BlockSchemaOutput, CredentialsMetaInput, SchemaField, - cost, ) from backend.util.exceptions import BlockExecutionError from ._config import firecrawl -@cost(BlockCost(2, BlockCostType.RUN)) class FirecrawlExtractBlock(Block): class Input(BlockSchemaInput): credentials: CredentialsMetaInput = firecrawl.credentials_field() @@ -74,4 +71,13 @@ class FirecrawlExtractBlock(Block): block_id=self.id, ) from e + # Firecrawl surfaces actual credit spend on extract responses + # (credits_used). 1 Firecrawl credit ≈ $0.001. + credits_used = getattr(extract_result, "credits_used", None) or 0 + self.merge_stats( + NodeExecutionStats( + provider_cost=credits_used * 0.001, + provider_cost_type="cost_usd", + ) + ) yield "data", extract_result.data diff --git a/autogpt_platform/backend/backend/blocks/firecrawl/map.py b/autogpt_platform/backend/backend/blocks/firecrawl/map.py index e2e04adac0..9d24da7237 100644 --- a/autogpt_platform/backend/backend/blocks/firecrawl/map.py +++ b/autogpt_platform/backend/backend/blocks/firecrawl/map.py @@ -2,6 +2,7 @@ from typing import Any from firecrawl import FirecrawlApp +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -50,6 +51,10 @@ class FirecrawlMapWebsiteBlock(Block): map_result = app.map( url=input_data.url, ) + # Firecrawl bills 1 credit (~$0.001) per map request. + self.merge_stats( + NodeExecutionStats(provider_cost=0.001, provider_cost_type="cost_usd") + ) # Convert SearchResult objects to dicts results_data = [ diff --git a/autogpt_platform/backend/backend/blocks/firecrawl/scrape.py b/autogpt_platform/backend/backend/blocks/firecrawl/scrape.py index 2c1a68d6d9..f7923cf07c 100644 --- a/autogpt_platform/backend/backend/blocks/firecrawl/scrape.py +++ b/autogpt_platform/backend/backend/blocks/firecrawl/scrape.py @@ -3,6 +3,7 @@ from typing import Any from firecrawl import FirecrawlApp from backend.blocks.firecrawl._api import ScrapeFormat +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -81,6 +82,11 @@ class FirecrawlScrapeBlock(Block): max_age=input_data.max_age, wait_for=input_data.wait_for, ) + # Firecrawl bills 1 credit (~$0.001) per scraped page; scrape is a + # single-page operation. 
+ self.merge_stats( + NodeExecutionStats(provider_cost=0.001, provider_cost_type="cost_usd") + ) yield "data", scrape_result for f in input_data.formats: diff --git a/autogpt_platform/backend/backend/blocks/firecrawl/search.py b/autogpt_platform/backend/backend/blocks/firecrawl/search.py index a2769a0f96..3c14bcf905 100644 --- a/autogpt_platform/backend/backend/blocks/firecrawl/search.py +++ b/autogpt_platform/backend/backend/blocks/firecrawl/search.py @@ -4,6 +4,7 @@ from firecrawl import FirecrawlApp from firecrawl.v2.types import ScrapeOptions from backend.blocks.firecrawl._api import ScrapeFormat +from backend.data.model import NodeExecutionStats from backend.sdk import ( APIKeyCredentials, Block, @@ -68,6 +69,17 @@ class FirecrawlSearchBlock(Block): wait_for=input_data.wait_for, ), ) + # Firecrawl bills per returned web result (~1 credit each). The + # SearchResponse structure exposes `.web` when scrape_options was + # requested; fall back to `limit` as an upper bound estimate. + web_results = getattr(scrape_result, "web", None) or [] + billed_units = max(len(web_results), 1) + self.merge_stats( + NodeExecutionStats( + provider_cost=billed_units * 0.001, + provider_cost_type="cost_usd", + ) + ) yield "data", scrape_result if hasattr(scrape_result, "web") and scrape_result.web: for site in scrape_result.web: diff --git a/autogpt_platform/backend/backend/blocks/google/__init__.py b/autogpt_platform/backend/backend/blocks/google/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/blocks/google/_drive.py b/autogpt_platform/backend/backend/blocks/google/_drive.py index cb2b52821c..c5ecc55701 100644 --- a/autogpt_platform/backend/backend/blocks/google/_drive.py +++ b/autogpt_platform/backend/backend/blocks/google/_drive.py @@ -133,10 +133,21 @@ def GoogleDriveFileField( if allowed_mime_types: picker_config["allowed_mime_types"] = list(allowed_mime_types) + agent_builder_hint = ( + "At runtime, feed this from an AgentGoogleDriveFileInputBlock with " + "matching allowed_views. NEVER hardcode a file ID in input_default " + "(including one parsed from a Drive URL the user pasted in chat) — " + "only the picker attaches the _credentials_id needed for auth." + ) + return SchemaField( default=None, title=title, - description=description, + description=( + f"{description.rstrip('.')}. {agent_builder_hint}" + if description + else agent_builder_hint + ), placeholder=placeholder or "Select from Google Drive", # Use google-drive-picker format so frontend renders existing component format="google-drive-picker", diff --git a/autogpt_platform/backend/backend/blocks/google/sheets_test.py b/autogpt_platform/backend/backend/blocks/google/sheets_test.py new file mode 100644 index 0000000000..40a28f3cc7 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/google/sheets_test.py @@ -0,0 +1,129 @@ +"""Edge-case tests for Google Sheets block credential handling. + +These pin the contract for the systemic auto-credential None-guard in +``Block._execute()``: any block with an auto-credential field (via +``GoogleDriveFileField`` etc.) that's called without resolved +credentials must surface a clean, user-facing ``BlockExecutionError`` +— never a wrapped ``TypeError`` (missing required kwarg) or +``AttributeError`` deep in the provider SDK. 
+""" + +import pytest + +from backend.blocks.google.sheets import GoogleSheetsReadBlock +from backend.util.exceptions import BlockExecutionError + + +@pytest.mark.asyncio +async def test_sheets_read_missing_credentials_yields_clean_error(): + """Valid spreadsheet but no resolved credentials -> the systemic + None-guard in ``Block._execute()`` yields a ``Missing credentials`` + error before ``run()`` is entered.""" + block = GoogleSheetsReadBlock() + input_data = { + "spreadsheet": { + "id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms", + "name": "Test Spreadsheet", + "mimeType": "application/vnd.google-apps.spreadsheet", + }, + "range": "Sheet1!A1:B2", + } + + with pytest.raises(BlockExecutionError, match="Missing credentials"): + async for _ in block.execute(input_data): + pass + + +@pytest.mark.asyncio +async def test_sheets_read_no_spreadsheet_still_hits_credentials_guard(): + """When neither spreadsheet nor credentials are present, the + credentials guard fires first (it runs before we hand off to + ``run()``). The user-facing message should still be the clean + ``Missing credentials`` one, not an opaque ``TypeError``.""" + block = GoogleSheetsReadBlock() + input_data = {"range": "Sheet1!A1:B2"} # no spreadsheet, no credentials + + with pytest.raises(BlockExecutionError, match="Missing credentials"): + async for _ in block.execute(input_data): + pass + + +@pytest.mark.asyncio +async def test_sheets_read_upstream_chained_value_skips_guard(mocker): + """A spreadsheet value chained in from an upstream input block (e.g. + ``AgentGoogleDriveFileInputBlock``) carries a resolved + ``_credentials_id`` that ``_acquire_auto_credentials`` didn't have + visibility into at prep time. The systemic None-guard must NOT + preempt run() in that case — otherwise every chained Drive-picker + pattern crashes with a bogus ``Missing credentials`` error. + + We short-circuit past the guard by patching the Google API client + build; any error that escapes from run() is fine as long as the + ``Missing credentials`` message never surfaces.""" + # Patch out the real Google Sheets client build so we don't hit the + # network and can detect we reached the provider SDK. + mocker.patch( + "backend.blocks.google.sheets.build", + side_effect=RuntimeError("api-boundary-reached"), + ) + + block = GoogleSheetsReadBlock() + input_data = { + "spreadsheet": { + "_credentials_id": "upstream-chained-cred-id", + "id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms", + "name": "Upstream-chained sheet", + "mimeType": "application/vnd.google-apps.spreadsheet", + }, + "range": "Sheet1!A1:B2", + } + + with pytest.raises(Exception) as exc_info: + async for _ in block.execute(input_data): + pass + + # The guard should skip (chained data present) and let us reach run(), + # which then hits the patched provider-SDK boundary. A "Missing + # credentials" error here would mean the None-guard broke the + # documented AgentGoogleDriveFileInputBlock chaining pattern. + assert "Missing credentials" not in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_sheets_read_upstream_chained_with_explicit_none_cred_id_skips_guard( + mocker, +): + """Sentry HIGH regression (thread PRRT_kwDOJKSTjM58sJfA): the + documented chained-upstream pattern ships the spreadsheet dict with + ``_credentials_id=None`` — the executor fills in the resolved id + between prep time and ``run()``. 
The previous ``_base.py`` guard + used ``field_value.get("_credentials_id")`` and treated the falsy + ``None`` value as "missing", raising ``BlockExecutionError`` on + every chained graph. + + Pin the contract: the presence of the ``_credentials_id`` key — not + its truthiness — is what signals "trust the skip". A dict with + ``_credentials_id: None`` must not preempt run().""" + mocker.patch( + "backend.blocks.google.sheets.build", + side_effect=RuntimeError("api-boundary-reached"), + ) + + block = GoogleSheetsReadBlock() + input_data = { + "spreadsheet": { + "_credentials_id": None, # explicit None — chained-upstream shape + "id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms", + "name": "Upstream-chained sheet (None cred_id)", + "mimeType": "application/vnd.google-apps.spreadsheet", + }, + "range": "Sheet1!A1:B2", + } + + with pytest.raises(Exception) as exc_info: + async for _ in block.execute(input_data): + pass + + # The guard must not raise "Missing credentials" for this shape. + # We expect to reach run() and hit the patched provider-SDK boundary. + assert "Missing credentials" not in str(exc_info.value) diff --git a/autogpt_platform/backend/backend/blocks/io.py b/autogpt_platform/backend/backend/blocks/io.py index e72ee5c097..2ef9999da4 100644 --- a/autogpt_platform/backend/backend/blocks/io.py +++ b/autogpt_platform/backend/backend/blocks/io.py @@ -737,7 +737,22 @@ class AgentGoogleDriveFileInputBlock(AgentInputBlock): ) super().__init__( id="d3b32f15-6fd7-40e3-be52-e083f51b19a2", - description="Block for selecting a file from Google Drive.", + description=( + "Agent-level input for a Google Drive file. REQUIRED for any " + "agent that reads or writes a Drive file (Sheets, Docs, " + "Slides, or generic Drive) — the picker is the only source " + "of the _credentials_id needed at runtime, so consuming " + "blocks cannot receive a hardcoded ID. Set allowed_views to " + 'match the consumer: ["SPREADSHEETS"] for Sheets, ' + '["DOCUMENTS"] for Docs, ["PRESENTATIONS"] for Slides ' + "(leave default for generic Drive). Wire `result` to the " + "consumer block's Drive field and leave that field unset in " + "the consumer's input_default. Example link to a Google " + 'Sheets block: {"source_name": "result", "sink_name": ' + '"spreadsheet"} (use "document" for Docs, "presentation" ' + "for Slides). Use one input block per distinct file; " + "multiple consumers of the same file share it." + ), disabled=not config.enable_agent_input_subtype_blocks, input_schema=AgentGoogleDriveFileInputBlock.Input, output_schema=AgentGoogleDriveFileInputBlock.Output, diff --git a/autogpt_platform/backend/backend/blocks/jina/search.py b/autogpt_platform/backend/backend/blocks/jina/search.py index 007dd5bc12..5c2ebfb39f 100644 --- a/autogpt_platform/backend/backend/blocks/jina/search.py +++ b/autogpt_platform/backend/backend/blocks/jina/search.py @@ -15,7 +15,7 @@ from backend.blocks.jina._auth import ( JinaCredentialsInput, ) from backend.blocks.search import GetRequest -from backend.data.model import SchemaField +from backend.data.model import NodeExecutionStats, SchemaField from backend.util.exceptions import BlockExecutionError from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host @@ -70,6 +70,13 @@ class SearchTheWebBlock(Block, GetRequest): block_id=self.id, ) from e + # Jina Reader Search: $0.01/query on the paid tier. Fixed per-query + # cost; routed through COST_USD so the platform cost log records + # real USD spend (costMicrodollars) alongside the credit charge. 
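# (Unit note, assuming costMicrodollars is denominated in millionths of a
# dollar as the name suggests: the fixed $0.01 query cost above lands in the
# platform cost log as 10,000 microdollars per search.)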
+ self.merge_stats( + NodeExecutionStats(provider_cost=0.01, provider_cost_type="cost_usd") + ) + # Output the search results yield "results", results @@ -128,10 +135,16 @@ class ExtractWebsiteContentBlock(Block, GetRequest): try: content = await self.get_request(url, json=False, headers=headers) except HTTPClientError as e: - yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}" + yield ( + "error", + f"Client error ({e.status_code}) fetching {input_data.url}: {e}", + ) return except HTTPServerError as e: - yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}" + yield ( + "error", + f"Server error ({e.status_code}) fetching {input_data.url}: {e}", + ) return except Exception as e: yield "error", f"Failed to fetch {input_data.url}: {e}" diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index 8543a03b69..86c8d96427 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -206,6 +206,10 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta): GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent" GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1" KIMI_K2 = "moonshotai/kimi-k2" + KIMI_K2_0905 = "moonshotai/kimi-k2-0905" + KIMI_K2_5 = "moonshotai/kimi-k2.5" + KIMI_K2_6 = "moonshotai/kimi-k2.6" + KIMI_K2_THINKING = "moonshotai/kimi-k2-thinking" QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507" QWEN3_CODER = "qwen/qwen3-coder" # Z.ai (Zhipu) models @@ -646,6 +650,24 @@ MODEL_METADATA = { LlmModel.KIMI_K2: ModelMetadata( "open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1 ), + LlmModel.KIMI_K2_0905: ModelMetadata( + "open_router", 262144, 262144, "Kimi K2 0905", "OpenRouter", "Moonshot AI", 1 + ), + LlmModel.KIMI_K2_5: ModelMetadata( + "open_router", 262144, 262144, "Kimi K2.5", "OpenRouter", "Moonshot AI", 1 + ), + LlmModel.KIMI_K2_6: ModelMetadata( + "open_router", 262144, 262144, "Kimi K2.6", "OpenRouter", "Moonshot AI", 2 + ), + LlmModel.KIMI_K2_THINKING: ModelMetadata( + "open_router", + 262144, + 262144, + "Kimi K2 Thinking", + "OpenRouter", + "Moonshot AI", + 2, + ), LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata( "open_router", 262144, @@ -1602,6 +1624,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): llm_call_count=retry_count + 1, llm_retry_count=retry_count, provider_cost=total_provider_cost, + provider_cost_type=( + "cost_usd" + if total_provider_cost is not None + else None + ), ) ) yield "response", response_obj @@ -1623,6 +1650,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): llm_call_count=retry_count + 1, llm_retry_count=retry_count, provider_cost=total_provider_cost, + provider_cost_type=( + "cost_usd" if total_provider_cost is not None else None + ), ) ) yield "response", {"response": response_text} @@ -1657,7 +1687,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): # All retries exhausted or user-error break: persist accumulated cost so # the executor can still charge/report the spend even on failure. 
if total_provider_cost is not None: - self.merge_stats(NodeExecutionStats(provider_cost=total_provider_cost)) + self.merge_stats( + NodeExecutionStats( + provider_cost=total_provider_cost, + provider_cost_type="cost_usd", + ) + ) raise RuntimeError(error_feedback_message) def response_format_instructions( diff --git a/autogpt_platform/backend/backend/blocks/orchestrator.py b/autogpt_platform/backend/backend/blocks/orchestrator.py index b2a6df8481..5979f90dde 100644 --- a/autogpt_platform/backend/backend/blocks/orchestrator.py +++ b/autogpt_platform/backend/backend/blocks/orchestrator.py @@ -376,20 +376,12 @@ class OrchestratorBlock(Block): re-raise carve-out for this reason. """ - def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int: - """Charge one extra runtime cost per LLM call beyond the first. - - In agent mode each iteration makes one LLM call. The first is already - covered by charge_usage(); this returns the number of additional - credits so the executor can bill the remaining calls post-completion. - - SDK-mode exemption: when the block runs via _execute_tools_sdk_mode, - the SDK manages its own conversation loop and only exposes aggregate - usage. We hardcode llm_call_count=1 there (the SDK does not report a - per-turn call count), so this method always returns 0 for SDK-mode - executions. Per-iteration billing does not apply to SDK mode. - """ - return max(0, execution_stats.llm_call_count - 1) + # OrchestratorBlock bills via BlockCostType.TOKENS + compute_token_credits, + # which aggregates input_token_count / output_token_count / cache_read / + # cache_creation across every LLM iteration into one post-flight charge. + # The per-iteration flat-fee path (Block.extra_runtime_cost → + # charge_extra_runtime_cost) would double-bill the same tokens, so + # OrchestratorBlock deliberately inherits the base-class no-op default. # MCP server name used by the Claude Code SDK execution mode. Keep in sync # with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode. @@ -1189,10 +1181,14 @@ class OrchestratorBlock(Block): not execution_params.execution_context.dry_run and tool_node_stats.error is None ): + # Charge the sub-block for telemetry / wallet debit. The + # return value is intentionally discarded: on_node_execution + # above ran the sub-block against this graph's own + # graph_stats_pair (manager.py:659-668), so its cost already + # lands in graph_stats.cost on the sub-block's completion. + # Re-merging here would double-count in telemetry / UI / audit. try: - tool_cost, _ = await execution_processor.charge_node_usage( - node_exec_entry, - ) + await execution_processor.charge_node_usage(node_exec_entry) except InsufficientBalanceError: # IBE must propagate — see OrchestratorBlock class docstring. 
# Log the billing failure here so the discarded tool result @@ -1214,9 +1210,6 @@ class OrchestratorBlock(Block): "tool execution was successful", sink_node_id, ) - tool_cost = 0 - if tool_cost > 0: - self.merge_stats(NodeExecutionStats(extra_cost=tool_cost)) # Get outputs from database after execution completes using database manager client node_outputs = await db_client.get_execution_outputs_by_node_exec_id( diff --git a/autogpt_platform/backend/backend/blocks/perplexity.py b/autogpt_platform/backend/backend/blocks/perplexity.py index a8b137ce2b..0cdf29a3de 100644 --- a/autogpt_platform/backend/backend/blocks/perplexity.py +++ b/autogpt_platform/backend/backend/blocks/perplexity.py @@ -13,6 +13,7 @@ from backend.blocks._base import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.blocks.llm import extract_openrouter_cost from backend.data.block import BlockInput from backend.data.model import ( APIKeyCredentials, @@ -98,14 +99,23 @@ class PerplexityBlock(Block): return _sanitize_perplexity_model(v) @classmethod - def validate_data(cls, data: BlockInput) -> str | None: + def validate_data( + cls, + data: BlockInput, + exclude_fields: set[str] | None = None, + ) -> str | None: """Sanitize the model field before JSON schema validation so that invalid values are replaced with the default instead of raising a - BlockInputError.""" + BlockInputError. + + Signature matches ``BlockSchema.validate_data`` (including the + optional ``exclude_fields`` kwarg added for dry-run credential + bypass) so Pyright doesn't flag this as an incompatible override. + """ model_value = data.get("model") if model_value is not None: data["model"] = _sanitize_perplexity_model(model_value).value - return super().validate_data(data) + return super().validate_data(data, exclude_fields=exclude_fields) system_prompt: str = SchemaField( title="System Prompt", @@ -230,12 +240,17 @@ class PerplexityBlock(Block): if "message" in choice and "annotations" in choice["message"]: annotations = choice["message"]["annotations"] - # Update execution stats + # Update execution stats. ``execution_stats`` is instance state, + # so always reset token counters — a response without ``usage`` + # must not leak a previous run's tokens into ``PlatformCostLog``. + self.execution_stats.input_token_count = 0 + self.execution_stats.output_token_count = 0 if response.usage: self.execution_stats.input_token_count = response.usage.prompt_tokens self.execution_stats.output_token_count = ( response.usage.completion_tokens ) + self._record_openrouter_cost(response) return {"response": response_content, "annotations": annotations or []} @@ -243,6 +258,17 @@ class PerplexityBlock(Block): logger.error(f"Error calling Perplexity: {e}") raise + def _record_openrouter_cost(self, response: Any) -> None: + """Feed OpenRouter's ``x-total-cost`` USD into execution stats for + the COST_USD resolver. Tag as ``cost_usd`` only when the value is + concrete and positive — leaving it unset on None/0 keeps the + billing gap observable instead of silently floored to 0. 
+ """ + cost_usd = extract_openrouter_cost(response) + self.execution_stats.provider_cost = cost_usd + if cost_usd is not None and cost_usd > 0: + self.execution_stats.provider_cost_type = "cost_usd" + async def run( self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs ) -> BlockOutput: diff --git a/autogpt_platform/backend/backend/blocks/pinecone.py b/autogpt_platform/backend/backend/blocks/pinecone.py index f882212ab2..270d224ecf 100644 --- a/autogpt_platform/backend/backend/blocks/pinecone.py +++ b/autogpt_platform/backend/backend/blocks/pinecone.py @@ -14,6 +14,7 @@ from backend.data.model import ( APIKeyCredentials, CredentialsField, CredentialsMetaInput, + NodeExecutionStats, SchemaField, ) from backend.integrations.providers import ProviderName @@ -160,10 +161,13 @@ class PineconeQueryBlock(Block): combined_text = "\n\n".join(texts) # Return both the raw matches and combined text - yield "results", { - "matches": results["matches"], - "combined_text": combined_text, - } + yield ( + "results", + { + "matches": results["matches"], + "combined_text": combined_text, + }, + ) yield "combined_results", combined_text except Exception as e: @@ -228,6 +232,13 @@ class PineconeInsertBlock(Block): ) idx.upsert(vectors=vectors, namespace=input_data.namespace) + self.merge_stats( + NodeExecutionStats( + provider_cost=float(len(vectors)), + provider_cost_type="items", + ) + ) + yield "upsert_response", "successfully upserted" except Exception as e: diff --git a/autogpt_platform/backend/backend/blocks/stagehand/_config.py b/autogpt_platform/backend/backend/blocks/stagehand/_config.py index 43ec6cd5ac..0bb609d664 100644 --- a/autogpt_platform/backend/backend/blocks/stagehand/_config.py +++ b/autogpt_platform/backend/backend/blocks/stagehand/_config.py @@ -1,8 +1,12 @@ from backend.sdk import BlockCostType, ProviderBuilder +# 1 credit per 3 walltime seconds. Block walltime proxies for the +# Browserbase session lifetime + the LLM call it issues. Interim until +# the block emits real provider_cost (USD) via merge_stats and migrates +# to COST_USD. stagehand = ( ProviderBuilder("stagehand") .with_api_key("STAGEHAND_API_KEY", "Stagehand API Key") - .with_base_cost(1, BlockCostType.RUN) + .with_base_cost(1, BlockCostType.SECOND, cost_divisor=3) .build() ) diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py deleted file mode 100644 index 441bc08a42..0000000000 --- a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py +++ /dev/null @@ -1,1020 +0,0 @@ -"""Tests for OrchestratorBlock per-iteration cost charging. - -The OrchestratorBlock in agent mode makes multiple LLM calls in a single -node execution. The executor uses ``Block.extra_runtime_cost`` to detect -this and charge ``base_cost * (llm_call_count - 1)`` extra credits after -the block completes. 
-""" - -import threading -from collections import defaultdict -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from backend.blocks._base import Block -from backend.blocks.orchestrator import ExecutionParams, OrchestratorBlock -from backend.data.execution import ExecutionContext, ExecutionStatus -from backend.data.model import NodeExecutionStats -from backend.executor import billing, manager -from backend.util.exceptions import InsufficientBalanceError - -# ── extra_runtime_cost hook ──────────────────────────────────────── - - -class _NoOpBlock(Block): - """Minimal concrete Block subclass that does not override extra_runtime_cost.""" - - def __init__(self): - super().__init__( - id="00000000-0000-0000-0000-000000000001", description="No-op test block" - ) - - def run(self, input_data, **kwargs): # type: ignore[override] - yield "out", {} - - -class TestExtraRuntimeCost: - """OrchestratorBlock opts into per-LLM-call billing via extra_runtime_cost.""" - - def test_orchestrator_returns_nonzero_for_multiple_calls(self): - block = OrchestratorBlock() - stats = NodeExecutionStats(llm_call_count=3) - assert block.extra_runtime_cost(stats) == 2 - - def test_orchestrator_returns_zero_for_single_call(self): - block = OrchestratorBlock() - stats = NodeExecutionStats(llm_call_count=1) - assert block.extra_runtime_cost(stats) == 0 - - def test_orchestrator_returns_zero_for_zero_calls(self): - block = OrchestratorBlock() - stats = NodeExecutionStats(llm_call_count=0) - assert block.extra_runtime_cost(stats) == 0 - - def test_default_block_returns_zero(self): - """A block that does not override extra_runtime_cost returns 0.""" - block = _NoOpBlock() - stats = NodeExecutionStats(llm_call_count=10) - assert block.extra_runtime_cost(stats) == 0 - - -# ── charge_extra_runtime_cost math ─────────────────────────────────── - - -@pytest.fixture() -def fake_node_exec(): - node_exec = MagicMock() - node_exec.user_id = "u" - node_exec.graph_exec_id = "g" - node_exec.graph_id = "g" - node_exec.node_exec_id = "ne" - node_exec.node_id = "n" - node_exec.block_id = "b" - node_exec.inputs = {} - return node_exec - - -@pytest.fixture() -def patched_processor(monkeypatch): - """ExecutionProcessor with stubbed db client / block lookup helpers. - - Returns the processor and a list of credit amounts spent so tests can - assert on what was charged. - - Note: ``ExecutionProcessor.__new__()`` bypasses ``__init__`` — if - ``__init__`` gains required state in the future this fixture will need - updating. 
- """ - spent: list[int] = [] - - class FakeDb: - def spend_credits(self, *, user_id, cost, metadata): - spent.append(cost) - return 1000 # remaining balance - - fake_block = MagicMock() - fake_block.name = "FakeBlock" - - monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) - monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) - monkeypatch.setattr( - billing, - "block_usage_cost", - lambda block, input_data, **_kw: (10, {"model": "claude-sonnet-4-6"}), - ) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - return proc, spent - - -class TestChargeExtraRuntimeCost: - @pytest.mark.asyncio - async def test_zero_extra_iterations_charges_nothing( - self, patched_processor, fake_node_exec - ): - proc, spent = patched_processor - cost, balance = await proc.charge_extra_runtime_cost( - fake_node_exec, extra_count=0 - ) - assert cost == 0 - assert balance == 0 - assert spent == [] - - @pytest.mark.asyncio - async def test_extra_iterations_multiplies_base_cost( - self, patched_processor, fake_node_exec - ): - proc, spent = patched_processor - cost, balance = await proc.charge_extra_runtime_cost( - fake_node_exec, extra_count=4 - ) - assert cost == 40 # 4 × 10 - assert balance == 1000 - assert spent == [40] - - @pytest.mark.asyncio - async def test_negative_extra_iterations_charges_nothing( - self, patched_processor, fake_node_exec - ): - proc, spent = patched_processor - cost, balance = await proc.charge_extra_runtime_cost( - fake_node_exec, extra_count=-1 - ) - assert cost == 0 - assert balance == 0 - assert spent == [] - - @pytest.mark.asyncio - async def test_capped_at_max(self, monkeypatch, fake_node_exec): - """Runaway llm_call_count is capped at _MAX_EXTRA_RUNTIME_COST.""" - - spent: list[int] = [] - - class FakeDb: - def spend_credits(self, *, user_id, cost, metadata): - spent.append(cost) - return 1000 - - fake_block = MagicMock() - fake_block.name = "FakeBlock" - - monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) - monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) - monkeypatch.setattr( - billing, - "block_usage_cost", - lambda block, input_data, **_kw: (10, {}), - ) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - cap = billing._MAX_EXTRA_RUNTIME_COST - cost, _ = await proc.charge_extra_runtime_cost( - fake_node_exec, extra_count=cap * 100 - ) - # Charged at most cap × 10 - assert cost == cap * 10 - assert spent == [cap * 10] - - @pytest.mark.asyncio - async def test_zero_base_cost_skips_charge(self, monkeypatch, fake_node_exec): - - spent: list[int] = [] - - class FakeDb: - def spend_credits(self, *, user_id, cost, metadata): - spent.append(cost) - return 0 - - fake_block = MagicMock() - fake_block.name = "FakeBlock" - - monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) - monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) - monkeypatch.setattr( - billing, "block_usage_cost", lambda block, input_data, **_kw: (0, {}) - ) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - cost, balance = await proc.charge_extra_runtime_cost( - fake_node_exec, extra_count=4 - ) - assert cost == 0 - assert balance == 0 - assert spent == [] - - @pytest.mark.asyncio - async def test_block_not_found_skips_charge(self, monkeypatch, fake_node_exec): - - spent: list[int] = [] - - class FakeDb: - def spend_credits(self, *, user_id, cost, metadata): - spent.append(cost) - return 0 - - monkeypatch.setattr(billing, "get_db_client", lambda: 
FakeDb()) - monkeypatch.setattr(billing, "get_block", lambda block_id: None) - monkeypatch.setattr( - billing, "block_usage_cost", lambda block, input_data, **_kw: (10, {}) - ) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - cost, balance = await proc.charge_extra_runtime_cost( - fake_node_exec, extra_count=3 - ) - assert cost == 0 - assert balance == 0 - assert spent == [] - - @pytest.mark.asyncio - async def test_propagates_insufficient_balance_error( - self, monkeypatch, fake_node_exec - ): - """Out-of-credits errors must propagate, not be silently swallowed.""" - - class FakeDb: - def spend_credits(self, *, user_id, cost, metadata): - raise InsufficientBalanceError( - user_id=user_id, - message="Insufficient balance", - balance=0, - amount=cost, - ) - - fake_block = MagicMock() - fake_block.name = "FakeBlock" - - monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) - monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) - monkeypatch.setattr( - billing, "block_usage_cost", lambda block, input_data, **_kw: (10, {}) - ) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - with pytest.raises(InsufficientBalanceError): - await proc.charge_extra_runtime_cost(fake_node_exec, extra_count=4) - - -# ── charge_node_usage ────────────────────────────────────────────── - - -class TestChargeNodeUsage: - """charge_node_usage delegates to billing.charge_usage with execution_count=0.""" - - @pytest.mark.asyncio - async def test_delegates_with_zero_execution_count( - self, monkeypatch, fake_node_exec - ): - """Nested tool charges should NOT inflate the per-execution counter.""" - - captured: dict = {} - - def fake_charge_usage(node_exec, execution_count): - captured["execution_count"] = execution_count - captured["node_exec"] = node_exec - return (5, 100) - - def fake_handle_low_balance( - db_client, user_id, current_balance, transaction_cost - ): - pass - - monkeypatch.setattr(billing, "charge_usage", fake_charge_usage) - monkeypatch.setattr(billing, "handle_low_balance", fake_handle_low_balance) - monkeypatch.setattr(billing, "get_db_client", lambda: MagicMock()) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - cost, balance = await proc.charge_node_usage(fake_node_exec) - assert cost == 5 - assert balance == 100 - assert captured["execution_count"] == 0 - - @pytest.mark.asyncio - async def test_calls_handle_low_balance_when_cost_nonzero( - self, monkeypatch, fake_node_exec - ): - """charge_node_usage should call handle_low_balance when total_cost > 0.""" - - low_balance_calls: list[dict] = [] - - def fake_charge_usage(node_exec, execution_count): - return (10, 50) - - def fake_handle_low_balance( - db_client, user_id, current_balance, transaction_cost - ): - low_balance_calls.append( - { - "user_id": user_id, - "current_balance": current_balance, - "transaction_cost": transaction_cost, - } - ) - - monkeypatch.setattr(billing, "charge_usage", fake_charge_usage) - monkeypatch.setattr(billing, "handle_low_balance", fake_handle_low_balance) - monkeypatch.setattr(billing, "get_db_client", lambda: MagicMock()) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - cost, balance = await proc.charge_node_usage(fake_node_exec) - assert cost == 10 - assert balance == 50 - assert len(low_balance_calls) == 1 - assert low_balance_calls[0]["user_id"] == "u" - assert low_balance_calls[0]["current_balance"] == 50 - assert low_balance_calls[0]["transaction_cost"] == 10 - - 
@pytest.mark.asyncio - async def test_skips_handle_low_balance_when_cost_zero( - self, monkeypatch, fake_node_exec - ): - """charge_node_usage should NOT call handle_low_balance when cost is 0.""" - - low_balance_calls: list = [] - - def fake_charge_usage(node_exec, execution_count): - return (0, 200) - - def fake_handle_low_balance( - db_client, user_id, current_balance, transaction_cost - ): - low_balance_calls.append(True) - - monkeypatch.setattr(billing, "charge_usage", fake_charge_usage) - monkeypatch.setattr(billing, "handle_low_balance", fake_handle_low_balance) - monkeypatch.setattr(billing, "get_db_client", lambda: MagicMock()) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - cost, balance = await proc.charge_node_usage(fake_node_exec) - assert cost == 0 - assert low_balance_calls == [] - - -# ── on_node_execution charging gate ──────────────────────────────── - - -class _FakeNode: - """Minimal stand-in for a ``Node`` object with a block attribute.""" - - def __init__(self, extra_charges: int = 0, block_name: str = "FakeBlock"): - self.block = MagicMock() - self.block.name = block_name - self.block.extra_runtime_cost = MagicMock(return_value=extra_charges) - - -class _FakeExecContext: - def __init__(self, dry_run: bool = False): - self.dry_run = dry_run - - -def _make_node_exec(dry_run: bool = False) -> MagicMock: - """Build a NodeExecutionEntry-like mock for on_node_execution tests.""" - ne = MagicMock() - ne.user_id = "u" - ne.graph_id = "g" - ne.graph_exec_id = "ge" - ne.node_id = "n" - ne.node_exec_id = "ne" - ne.block_id = "b" - ne.inputs = {} - ne.execution_context = _FakeExecContext(dry_run=dry_run) - return ne - - -@pytest.fixture() -def gated_processor(monkeypatch): - """ExecutionProcessor with on_node_execution's downstream calls stubbed. - - Lets tests flip the gate conditions (status, extra_runtime_cost result, - llm_call_count, dry_run) and observe whether charge_extra_runtime_cost - was called. - """ - - calls: dict[str, list] = { - "charge_extra_runtime_cost": [], - "handle_low_balance": [], - "handle_insufficient_funds_notif": [], - } - - # Stub node lookup + DB client so the wrapper doesn't touch real infra. - fake_db = MagicMock() - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2)) - monkeypatch.setattr(manager, "get_db_async_client", lambda: fake_db) - monkeypatch.setattr(billing, "get_db_client", lambda: fake_db) - # get_block is called by LogMetadata construction in on_node_execution. - monkeypatch.setattr( - manager, - "get_block", - lambda block_id: MagicMock(name="FakeBlock"), - ) - # Persistence + cost logging are not under test here. - monkeypatch.setattr( - manager, - "async_update_node_execution_status", - AsyncMock(return_value=None), - ) - monkeypatch.setattr( - manager, - "async_update_graph_execution_state", - AsyncMock(return_value=None), - ) - monkeypatch.setattr( - manager, - "log_system_credential_cost", - AsyncMock(return_value=None), - ) - - proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) - - # Control the status returned by the inner execution call. 
- inner_result = {"status": ExecutionStatus.COMPLETED, "llm_call_count": 3} - - async def fake_inner( - self, - *, - node, - node_exec, - node_exec_progress, - stats, - db_client, - log_metadata, - nodes_input_masks=None, - nodes_to_skip=None, - ): - stats.llm_call_count = inner_result["llm_call_count"] - return MagicMock(wall_time=0.1, cpu_time=0.1), inner_result["status"] - - monkeypatch.setattr( - manager.ExecutionProcessor, - "_on_node_execution", - fake_inner, - ) - - async def fake_charge_extra(node_exec, extra_count): - calls["charge_extra_runtime_cost"].append(extra_count) - return (extra_count * 10, 500) - - monkeypatch.setattr(billing, "charge_extra_runtime_cost", fake_charge_extra) - - def fake_low_balance(db_client, user_id, current_balance, transaction_cost): - calls["handle_low_balance"].append( - { - "user_id": user_id, - "current_balance": current_balance, - "transaction_cost": transaction_cost, - } - ) - - monkeypatch.setattr(billing, "handle_low_balance", fake_low_balance) - - def fake_notif(db_client, user_id, graph_id, e): - calls["handle_insufficient_funds_notif"].append( - {"user_id": user_id, "graph_id": graph_id, "error": e} - ) - - monkeypatch.setattr(billing, "handle_insufficient_funds_notif", fake_notif) - - return proc, calls, inner_result, fake_db, NodeExecutionStats - - -@pytest.mark.asyncio -async def test_on_node_execution_charges_extra_iterations_when_gate_passes( - gated_processor, -): - """COMPLETED + extra_runtime_cost > 0 + not dry_run → charged.""" - - proc, calls, inner, fake_db, _ = gated_processor - inner["status"] = ExecutionStatus.COMPLETED - inner["llm_call_count"] = 3 # → extra_charges = 2 - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2)) - - stats_pair = ( - MagicMock( - node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 - ), - threading.Lock(), - ) - await proc.on_node_execution( - node_exec=_make_node_exec(dry_run=False), - node_exec_progress=MagicMock(), - nodes_input_masks=None, - graph_stats_pair=stats_pair, - ) - assert calls["charge_extra_runtime_cost"] == [2] - # handle_low_balance must be called with the remaining balance returned by - # charge_extra_runtime_cost (500) so users are alerted when balance drops low. 
- assert len(calls["handle_low_balance"]) == 1 - - -@pytest.mark.asyncio -async def test_on_node_execution_skips_when_status_not_completed(gated_processor): - - proc, calls, inner, fake_db, _ = gated_processor - inner["status"] = ExecutionStatus.FAILED - inner["llm_call_count"] = 5 - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4)) - - stats_pair = ( - MagicMock( - node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 - ), - threading.Lock(), - ) - await proc.on_node_execution( - node_exec=_make_node_exec(dry_run=False), - node_exec_progress=MagicMock(), - nodes_input_masks=None, - graph_stats_pair=stats_pair, - ) - assert calls["charge_extra_runtime_cost"] == [] - - -@pytest.mark.asyncio -async def test_on_node_execution_skips_when_extra_charges_zero(gated_processor): - - proc, calls, inner, fake_db, _ = gated_processor - inner["status"] = ExecutionStatus.COMPLETED - inner["llm_call_count"] = 5 - # Block returns 0 extra charges (base class default) - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0)) - - stats_pair = ( - MagicMock( - node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 - ), - threading.Lock(), - ) - await proc.on_node_execution( - node_exec=_make_node_exec(dry_run=False), - node_exec_progress=MagicMock(), - nodes_input_masks=None, - graph_stats_pair=stats_pair, - ) - assert calls["charge_extra_runtime_cost"] == [] - - -@pytest.mark.asyncio -async def test_on_node_execution_skips_when_dry_run(gated_processor): - - proc, calls, inner, fake_db, _ = gated_processor - inner["status"] = ExecutionStatus.COMPLETED - inner["llm_call_count"] = 5 - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4)) - - stats_pair = ( - MagicMock( - node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 - ), - threading.Lock(), - ) - await proc.on_node_execution( - node_exec=_make_node_exec(dry_run=True), - node_exec_progress=MagicMock(), - nodes_input_masks=None, - graph_stats_pair=stats_pair, - ) - assert calls["charge_extra_runtime_cost"] == [] - - -@pytest.mark.asyncio -async def test_on_node_execution_insufficient_balance_records_error_and_notifies( - monkeypatch, - gated_processor, -): - """When extra-iteration charging fails with InsufficientBalanceError: - - - the run still reports COMPLETED (the work is already done) - - execution_stats.error is NOT set (would flip node_error_count and - leak balance amounts into persisted node_stats — see manager.py - comment in the IBE handler) - - _handle_insufficient_funds_notif is called so the user is notified - - the structured ERROR log is the alerting hook - """ - - proc, calls, inner, fake_db, _ = gated_processor - inner["status"] = ExecutionStatus.COMPLETED - inner["llm_call_count"] = 4 - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3)) - - async def raise_ibe(node_exec, extra_count): - raise InsufficientBalanceError( - user_id=node_exec.user_id, - message="Insufficient balance", - balance=0, - amount=extra_count * 10, - ) - - monkeypatch.setattr(billing, "charge_extra_runtime_cost", raise_ibe) - - stats_pair = ( - MagicMock( - node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 - ), - threading.Lock(), - ) - result_stats = await proc.on_node_execution( - node_exec=_make_node_exec(dry_run=False), - node_exec_progress=MagicMock(), - nodes_input_masks=None, - graph_stats_pair=stats_pair, - ) - # error stays None — node ran to completion, only the post-hoc - # charge failed. 
Setting .error would (a) flip node_error_count++ - # creating an "errored COMPLETED node" inconsistency, and (b) leak - # balance amounts into persisted node_stats. - assert result_stats.error is None - # User notification fired. - assert len(calls["handle_insufficient_funds_notif"]) == 1 - assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u" - - -# ── Orchestrator _execute_single_tool_with_manager charging gates ── - - -async def _run_tool_exec_with_stats( - *, - dry_run: bool, - tool_stats_error, - charge_node_usage_mock=None, -): - """Invoke _execute_single_tool_with_manager against fully mocked deps - and return (charge_call_count, merge_stats_calls). - - Used to prove the dry_run and error guards around charge_node_usage - behave as documented, and that InsufficientBalanceError propagates. - """ - block = OrchestratorBlock() - - # Mocked async DB client used inside orchestrator. - mock_db_client = AsyncMock() - mock_target_node = MagicMock() - mock_target_node.block_id = "test-block-id" - mock_target_node.input_default = {} - mock_db_client.get_node.return_value = mock_target_node - mock_node_exec_result = MagicMock() - mock_node_exec_result.node_exec_id = "test-tool-exec-id" - mock_db_client.upsert_execution_input.return_value = ( - mock_node_exec_result, - {"query": "t"}, - ) - mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"} - - # ExecutionProcessor mock: on_node_execution returns supplied error. - mock_processor = AsyncMock() - mock_processor.running_node_execution = defaultdict(MagicMock) - mock_processor.execution_stats = MagicMock() - mock_processor.execution_stats_lock = threading.Lock() - mock_node_stats = MagicMock() - mock_node_stats.error = tool_stats_error - mock_processor.on_node_execution = AsyncMock(return_value=mock_node_stats) - mock_processor.charge_node_usage = charge_node_usage_mock or AsyncMock( - return_value=(10, 990) - ) - - # Build a tool_info shaped like _build_tool_info_from_args output. 
- tool_call = MagicMock() - tool_call.id = "call-1" - tool_call.name = "search_keywords" - tool_call.arguments = '{"query":"t"}' - tool_def = { - "type": "function", - "function": { - "name": "search_keywords", - "_sink_node_id": "test-sink-node-id", - "_field_mapping": {}, - "parameters": { - "properties": {"query": {"type": "string"}}, - "required": ["query"], - }, - }, - } - tool_info = OrchestratorBlock._build_tool_info_from_args( - tool_call_id="call-1", - tool_name="search_keywords", - tool_args={"query": "t"}, - tool_def=tool_def, - ) - - exec_params = ExecutionParams( - user_id="u", - graph_id="g", - node_id="n", - graph_version=1, - graph_exec_id="ge", - node_exec_id="ne", - execution_context=ExecutionContext( - human_in_the_loop_safe_mode=False, dry_run=dry_run - ), - ) - - with patch( - "backend.blocks.orchestrator.get_database_manager_async_client", - return_value=mock_db_client, - ): - try: - await block._execute_single_tool_with_manager( - tool_info, exec_params, mock_processor, responses_api=False - ) - raised = None - except Exception as e: - raised = e - - return mock_processor.charge_node_usage, raised - - -@pytest.mark.asyncio -async def test_tool_execution_skips_charging_on_dry_run(): - """dry_run=True → charge_node_usage is NOT called.""" - charge_mock, raised = await _run_tool_exec_with_stats( - dry_run=True, tool_stats_error=None - ) - assert raised is None - assert charge_mock.call_count == 0 - - -@pytest.mark.asyncio -async def test_tool_execution_skips_charging_on_failed_tool(): - """tool_node_stats.error is an Exception → charge_node_usage NOT called.""" - charge_mock, raised = await _run_tool_exec_with_stats( - dry_run=False, tool_stats_error=RuntimeError("tool blew up") - ) - assert raised is None - assert charge_mock.call_count == 0 - - -@pytest.mark.asyncio -async def test_tool_execution_skips_charging_on_cancelled_tool(): - """Cancellation (BaseException subclass) → charge_node_usage NOT called. - - Guards the fix for sentry's BaseException concern: the old - `isinstance(error, Exception)` check would have treated CancelledError - as "no error" and billed the user for a terminated run. - """ - import asyncio as _asyncio - - charge_mock, raised = await _run_tool_exec_with_stats( - dry_run=False, tool_stats_error=_asyncio.CancelledError() - ) - assert raised is None - assert charge_mock.call_count == 0 - - -@pytest.mark.asyncio -async def test_tool_execution_insufficient_balance_propagates(): - """InsufficientBalanceError from charge_node_usage must propagate out. - - If this leaked into a ToolCallResult the LLM loop would keep running - with 'tool failed' errors and the user would get unpaid work. - """ - raising_charge = AsyncMock( - side_effect=InsufficientBalanceError( - user_id="u", message="nope", balance=0, amount=10 - ) - ) - _, raised = await _run_tool_exec_with_stats( - dry_run=False, - tool_stats_error=None, - charge_node_usage_mock=raising_charge, - ) - assert isinstance(raised, InsufficientBalanceError) - - -@pytest.mark.asyncio -async def test_tool_execution_on_node_execution_returns_none_sets_is_error(): - """on_node_execution returning None (swallowed by @async_error_logged) must - result in a tool response with _is_error=True so the LLM loop knows the - tool failed and does not treat a silent error as a successful execution. 
- """ - block = OrchestratorBlock() - - mock_db_client = AsyncMock() - mock_target_node = MagicMock() - mock_target_node.block_id = "test-block-id" - mock_target_node.input_default = {} - mock_db_client.get_node.return_value = mock_target_node - mock_node_exec_result = MagicMock() - mock_node_exec_result.node_exec_id = "test-tool-exec-id" - mock_db_client.upsert_execution_input.return_value = ( - mock_node_exec_result, - {"query": "t"}, - ) - - mock_processor = AsyncMock() - mock_processor.running_node_execution = defaultdict(MagicMock) - mock_processor.execution_stats = MagicMock() - mock_processor.execution_stats_lock = threading.Lock() - # on_node_execution returns None — simulates @async_error_logged(swallow=True) - # swallowing an internal error - mock_processor.on_node_execution = AsyncMock(return_value=None) - - tool_call = MagicMock() - tool_call.id = "call-none" - tool_call.name = "search_keywords" - tool_call.arguments = '{"query":"t"}' - tool_def = { - "type": "function", - "function": { - "name": "search_keywords", - "_sink_node_id": "test-sink-node-id", - "_field_mapping": {}, - "parameters": { - "properties": {"query": {"type": "string"}}, - "required": ["query"], - }, - }, - } - tool_info = OrchestratorBlock._build_tool_info_from_args( - tool_call_id="call-none", - tool_name="search_keywords", - tool_args={"query": "t"}, - tool_def=tool_def, - ) - - exec_params = ExecutionParams( - user_id="u", - graph_id="g", - node_id="n", - graph_version=1, - graph_exec_id="ge", - node_exec_id="ne", - execution_context=ExecutionContext( - human_in_the_loop_safe_mode=False, dry_run=False - ), - ) - - with patch( - "backend.blocks.orchestrator.get_database_manager_async_client", - return_value=mock_db_client, - ): - resp = await block._execute_single_tool_with_manager( - tool_info, exec_params, mock_processor, responses_api=False - ) - - assert resp.get("_is_error") is True - # charge_node_usage must NOT be called for a failed tool execution - mock_processor.charge_node_usage.assert_not_called() - - -# ── on_node_execution FAILED + InsufficientBalanceError notification ── - - -@pytest.mark.asyncio -async def test_on_node_execution_failed_ibe_sends_notification( - monkeypatch, - gated_processor, -): - """When status == FAILED and execution_stats.error is InsufficientBalanceError, - _handle_insufficient_funds_notif must be called. - - This path fires when a nested tool charge inside the orchestrator raises - InsufficientBalanceError, which propagates out of the block's run() generator - and is caught by _on_node_execution's broad except, setting status=FAILED and - execution_stats.error=IBE. on_node_execution's post-execution block then - sends the user notification so they understand why the run stopped. - """ - - proc, calls, inner, fake_db, NodeExecutionStats = gated_processor - ibe = InsufficientBalanceError( - user_id="u", - message="Insufficient balance", - balance=0, - amount=30, - ) - - # Simulate _on_node_execution returning FAILED with IBE in stats.error. 
- async def fake_inner_failed( - self, - *, - node, - node_exec, - node_exec_progress, - stats, - db_client, - log_metadata, - nodes_input_masks=None, - nodes_to_skip=None, - ): - stats.error = ibe - return MagicMock(wall_time=0.1, cpu_time=0.1), ExecutionStatus.FAILED - - monkeypatch.setattr( - manager.ExecutionProcessor, - "_on_node_execution", - fake_inner_failed, - ) - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0)) - - stats_pair = ( - MagicMock( - node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 - ), - threading.Lock(), - ) - await proc.on_node_execution( - node_exec=_make_node_exec(dry_run=False), - node_exec_progress=MagicMock(), - nodes_input_masks=None, - graph_stats_pair=stats_pair, - ) - # The notification must have fired so the user knows why their run stopped. - assert len(calls["handle_insufficient_funds_notif"]) == 1 - assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u" - # charge_extra_runtime_cost must NOT be called — status is FAILED. - assert calls["charge_extra_runtime_cost"] == [] - - -# ── Billing leak: non-IBE exception during extra-iteration charging ── - - -@pytest.mark.asyncio -async def test_on_node_execution_non_ibe_billing_failure_keeps_completed( - monkeypatch, - gated_processor, -): - """When charge_extra_runtime_cost raises a non-IBE exception (e.g. DB outage): - - - execution_stats.error stays None (node ran to completion) - - status stays COMPLETED (work already done) - - the billing_leak error is logged but does not corrupt execution_stats - """ - proc, calls, inner, fake_db, _ = gated_processor - inner["status"] = ExecutionStatus.COMPLETED - inner["llm_call_count"] = 4 - fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3)) - - async def raise_conn_error(node_exec, extra_count): - raise ConnectionError("DB connection lost") - - monkeypatch.setattr(billing, "charge_extra_runtime_cost", raise_conn_error) - - stats_pair = ( - MagicMock( - node_count=0, - nodes_cputime=0, - nodes_walltime=0, - cost=0, - node_error_count=0, - ), - threading.Lock(), - ) - result_stats = await proc.on_node_execution( - node_exec=_make_node_exec(dry_run=False), - node_exec_progress=MagicMock(), - nodes_input_masks=None, - graph_stats_pair=stats_pair, - ) - # error stays None — node completed, only billing failed. - assert result_stats.error is None - # No notification was sent (only IBE triggers notification). 
- assert len(calls["handle_insufficient_funds_notif"]) == 0 - - -# ── _charge_usage with execution_count=0 ── - - -class TestChargeUsageZeroExecutionCount: - """Verify _charge_usage(node_exec, 0) does not invoke execution_usage_cost.""" - - def test_execution_count_zero_skips_execution_tier(self, monkeypatch): - """_charge_usage with execution_count=0 must not call execution_usage_cost.""" - execution_tier_called = [] - - def fake_execution_usage_cost(count): - execution_tier_called.append(count) - return (100, count) - - spent: list[int] = [] - - class FakeDb: - def spend_credits(self, *, user_id, cost, metadata): - spent.append(cost) - return 500 - - fake_block = MagicMock() - fake_block.name = "FakeBlock" - - monkeypatch.setattr(billing, "get_db_client", lambda: FakeDb()) - monkeypatch.setattr(billing, "get_block", lambda block_id: fake_block) - monkeypatch.setattr( - billing, - "block_usage_cost", - lambda block, input_data, **_kw: (10, {}), - ) - monkeypatch.setattr(billing, "execution_usage_cost", fake_execution_usage_cost) - - ne = MagicMock() - ne.user_id = "u" - ne.graph_exec_id = "ge" - ne.graph_id = "g" - ne.node_exec_id = "ne" - ne.node_id = "n" - ne.block_id = "b" - ne.inputs = {} - - total_cost, remaining = billing.charge_usage(ne, 0) - assert total_cost == 10 # block cost only - assert remaining == 500 - assert spent == [10] - # execution_usage_cost must NOT have been called - assert execution_tier_called == [] diff --git a/autogpt_platform/backend/backend/blocks/video/narration.py b/autogpt_platform/backend/backend/blocks/video/narration.py index 39b9c481b0..ed3835ec03 100644 --- a/autogpt_platform/backend/backend/blocks/video/narration.py +++ b/autogpt_platform/backend/backend/blocks/video/narration.py @@ -27,7 +27,7 @@ from backend.blocks.video._utils import ( strip_chapters_inplace, ) from backend.data.execution import ExecutionContext -from backend.data.model import CredentialsField, SchemaField +from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField from backend.util.exceptions import BlockExecutionError from backend.util.file import MediaFileType, get_exec_file_path, store_media_file @@ -44,7 +44,8 @@ class VideoNarrationBlock(Block): ) script: str = SchemaField(description="Narration script text") voice_id: str = SchemaField( - description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel + description="ElevenLabs voice ID", + default="21m00Tcm4TlvDq8ikWAM", # Rachel ) model_id: Literal[ "eleven_multilingual_v2", @@ -124,6 +125,26 @@ class VideoNarrationBlock(Block): return_format="for_block_output", ) + # Models that consume 0.5 credits per character (v2.5 tier). All other + # models default to 1.0 credit per character. + _HALF_RATE_MODELS = {"eleven_flash_v2_5", "eleven_turbo_v2_5"} + # ElevenLabs Starter plan: $5 / 30K credits = $0.000167 / credit. + _USD_PER_CREDIT = 0.000167 + + def _record_script_cost(self, script: str, model_id: str) -> None: + """Emit provider_cost (USD) for the narration run so the COST_USD + resolver can bill real ElevenLabs spend. Flash/Turbo v2.5 bill at + half the char rate of Multilingual/Turbo v2. 
+ """ + credits_per_char = 0.5 if model_id in self._HALF_RATE_MODELS else 1.0 + script_usd = len(script) * self._USD_PER_CREDIT * credits_per_char + self.merge_stats( + NodeExecutionStats( + provider_cost=script_usd, + provider_cost_type="cost_usd", + ) + ) + def _generate_narration_audio( self, api_key: str, script: str, voice_id: str, model_id: str ) -> bytes: @@ -223,6 +244,8 @@ class VideoNarrationBlock(Block): input_data.model_id, ) + self._record_script_cost(input_data.script, input_data.model_id) + # Save audio to exec file path audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3") audio_abspath = get_exec_file_path( diff --git a/autogpt_platform/backend/backend/blocks/zerobounce/validate_emails.py b/autogpt_platform/backend/backend/blocks/zerobounce/validate_emails.py index 6a461b4aa8..57c0ed3ef7 100644 --- a/autogpt_platform/backend/backend/blocks/zerobounce/validate_emails.py +++ b/autogpt_platform/backend/backend/blocks/zerobounce/validate_emails.py @@ -21,7 +21,7 @@ from backend.blocks.zerobounce._auth import ( ZeroBounceCredentials, ZeroBounceCredentialsInput, ) -from backend.data.model import CredentialsField, SchemaField +from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField class Response(BaseModel): @@ -140,20 +140,22 @@ class ValidateEmailsBlock(Block): ) ], test_mock={ - "validate_email": lambda email, ip_address, credentials: ZBValidateResponse( - data={ - "address": email, - "status": ZBValidateStatus.valid, - "sub_status": ZBValidateSubStatus.allowed, - "account": "test", - "domain": "test.com", - "did_you_mean": None, - "domain_age_days": None, - "free_email": False, - "mx_found": False, - "mx_record": None, - "smtp_provider": None, - } + "validate_email": lambda email, ip_address, credentials: ( + ZBValidateResponse( + data={ + "address": email, + "status": ZBValidateStatus.valid, + "sub_status": ZBValidateSubStatus.allowed, + "account": "test", + "domain": "test.com", + "did_you_mean": None, + "domain_age_days": None, + "free_email": False, + "mx_found": False, + "mx_record": None, + "smtp_provider": None, + } + ) ) }, ) @@ -176,6 +178,13 @@ class ValidateEmailsBlock(Block): input_data.email, input_data.ip_address, credentials ) + # ZeroBounce bills $0.008 per validated email on the paid tier. + # Routed through COST_USD so platform cost telemetry captures real + # USD spend; the resolver still bills 2 credits per call. + self.merge_stats( + NodeExecutionStats(provider_cost=0.008, provider_cost_type="cost_usd") + ) + response_model = Response(**response.__dict__) yield "response", response_model diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py new file mode 100644 index 0000000000..1d0da8ce7e --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py @@ -0,0 +1,364 @@ +"""Extended-thinking wire support for the baseline (OpenRouter) path. + +OpenRouter routes that support extended thinking (Anthropic Claude and +Moonshot Kimi today) expose reasoning through non-OpenAI extension fields +that the OpenAI Python SDK doesn't model: + +* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``. +* ``reasoning_content`` — DeepSeek / some OpenRouter routes. +* ``reasoning_details`` — structured list shipped with the unified + ``reasoning`` request param. 
+ +This module keeps the wire-level concerns in one place: + +* :class:`OpenRouterDeltaExtension` validates the extension dict pulled off + ``ChoiceDelta.model_extra`` into typed pydantic models — no ``getattr`` + + ``isinstance`` duck-typing at the call site. +* :class:`BaselineReasoningEmitter` owns the reasoning block lifecycle for + one streaming round and emits ``StreamReasoning*`` events so the caller + only has to plumb the events into its pending queue. +* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the + OpenAI client call. Returns ``None`` for routes without reasoning + support (see :func:`_is_reasoning_route`). +""" + +from __future__ import annotations + +import logging +import time +import uuid +from typing import Any + +from openai.types.chat.chat_completion_chunk import ChoiceDelta +from pydantic import BaseModel, ConfigDict, Field, ValidationError + +from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamBaseResponse, + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, +) + +logger = logging.getLogger(__name__) + + +_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"}) + +# Coalescing thresholds for ``StreamReasoningDelta`` emission. OpenRouter's +# Kimi K2.6 endpoint tokenises reasoning at a much finer grain than Anthropic +# (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without +# coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React +# re-render of the non-virtualised chat list, which paint-storms the browser +# main thread and freezes the UI. Batching into ~64-char / ~50 ms windows +# cuts the event rate ~150x while staying snappy enough that the Reasoning +# collapse still feels live (well under the ~100 ms perceptual threshold). +# Per-delta persistence to ``session.messages`` stays granular — we only +# coalesce the *wire* emission. +_COALESCE_MIN_CHARS = 64 +_COALESCE_MAX_INTERVAL_MS = 50.0 + + +class ReasoningDetail(BaseModel): + """One entry in OpenRouter's ``reasoning_details`` list. + + OpenRouter ships ``type: "reasoning.text"`` / ``"reasoning.summary"`` / + ``"reasoning.encrypted"`` entries. Only the first two carry + user-visible text; encrypted entries are opaque and omitted from the + rendered collapse. Unknown future types are tolerated (``extra="ignore"``) + so an upstream addition doesn't crash the stream — but their ``text`` / + ``summary`` fields are NOT surfaced because they may carry provider + metadata rather than user-visible reasoning (see + :attr:`visible_text`). + """ + + model_config = ConfigDict(extra="ignore") + + type: str | None = None + text: str | None = None + summary: str | None = None + + @property + def visible_text(self) -> str: + """Return the human-readable text for this entry, or ``""``. + + Only entries with a recognised reasoning type (``reasoning.text`` / + ``reasoning.summary``) surface text; unknown or encrypted types + return an empty string even if they carry a ``text`` / + ``summary`` field, to guard against future provider metadata + being rendered as reasoning in the UI. Entries missing a + ``type`` are treated as text (pre-``reasoning_details`` OpenRouter + payloads omit the field). + """ + if self.type is not None and self.type not in _VISIBLE_REASONING_TYPES: + return "" + return self.text or self.summary or "" + + +class OpenRouterDeltaExtension(BaseModel): + """Non-OpenAI fields OpenRouter adds to streaming deltas. 
+ + Instantiate via :meth:`from_delta` which pulls the extension dict off + ``ChoiceDelta.model_extra`` (where pydantic v2 stashes fields that + aren't part of the declared schema) and validates it through this + model. That keeps the parser honest — malformed entries surface as + validation errors rather than silent ``None``-coalesce bugs — and + avoids the ``getattr`` + ``isinstance`` duck-typing the earlier inline + extractor relied on. + """ + + model_config = ConfigDict(extra="ignore") + + reasoning: str | None = None + reasoning_content: str | None = None + reasoning_details: list[ReasoningDetail] = Field(default_factory=list) + + @classmethod + def from_delta(cls, delta: ChoiceDelta) -> "OpenRouterDeltaExtension": + """Build an extension view from ``delta.model_extra``. + + Malformed provider payloads (e.g. ``reasoning_details`` shipped as + a string rather than a list) surface as a ``ValidationError`` which + is logged and swallowed — returning an empty extension so the rest + of the stream (valid text / tool calls) keeps flowing. An optional + feature's corrupted wire data must never abort the whole stream. + """ + try: + return cls.model_validate(delta.model_extra or {}) + except ValidationError as exc: + logger.warning( + "[Baseline] Dropping malformed OpenRouter reasoning payload: %s", + exc, + ) + return cls() + + def visible_text(self) -> str: + """Concatenated reasoning text, pulled from whichever channel is set. + + Priority: the legacy ``reasoning`` string, then DeepSeek's + ``reasoning_content``, then the concatenation of text-bearing + entries in ``reasoning_details``. Only one channel is set per + provider in practice; the priority order just makes the fallback + deterministic if a provider ever emits multiple. + """ + if self.reasoning: + return self.reasoning + if self.reasoning_content: + return self.reasoning_content + return "".join(d.visible_text for d in self.reasoning_details) + + +def _is_reasoning_route(model: str) -> bool: + """Return True when the route supports OpenRouter's ``reasoning`` extension. + + OpenRouter exposes reasoning tokens via a unified ``reasoning`` request + param that works on any provider that supports extended thinking — + currently Anthropic (Claude Opus / Sonnet) and Moonshot (Kimi K2.6 + + kimi-k2-thinking) advertise it in their ``supported_parameters``. + Other providers silently drop the field, but we skip it anyway to keep + the payload tight and avoid confusing cache diagnostics. + + Kept separate from :func:`backend.copilot.baseline.service._is_anthropic_model` + because ``cache_control`` is strictly Anthropic-specific (Moonshot does + its own auto-caching), so the two gates must not conflate. + + Both the Claude and Kimi matches are anchored to the provider + prefix (or to a bare model id with no prefix at all) to avoid + substring false positives — a custom ``some-other-provider/claude-mock`` + or ``provider/hakimi-large`` configured via + ``CHAT_FAST_STANDARD_MODEL`` must NOT inherit the reasoning + extra_body and take a 400 from its upstream. Recognised shapes: + + * Claude — ``anthropic/`` or ``anthropic.`` provider prefix, or a + bare ``claude-`` model id with no provider prefix + (``claude-opus-4.7``, ``anthropic/claude-sonnet-4-6``, + ``anthropic.claude-3-5-sonnet``). A non-Anthropic prefix like + ``someprovider/claude-mock`` is rejected on purpose. + * Kimi — ``moonshotai/`` provider prefix, or a ``kimi-`` model id + with no provider prefix (``kimi-k2.6``, + ``moonshotai/kimi-k2-thinking``). 
Like Claude, a non-Moonshot + prefix is rejected — exception: ``openrouter/kimi-k2.6`` stays + recognised because ``openrouter/`` is how we route to Moonshot + today and changing that would be a behaviour regression for + existing deployments. + """ + lowered = model.lower() + if lowered.startswith(("anthropic/", "anthropic.")): + return True + if lowered.startswith("moonshotai/"): + return True + # ``openrouter/`` historically routes to whatever the default + # upstream for the model is — for kimi that's Moonshot, so accept + # ``openrouter/kimi-...`` here. Any other ``openrouter/`` model + # (e.g. ``openrouter/auto``) is caught by the provider-prefix + # rejection just below and does not opt into reasoning. + if lowered.startswith("openrouter/kimi-"): + return True + if "/" in lowered: + # Any other provider prefix is a custom / non-Anthropic / + # non-Moonshot route and must not opt into reasoning. This + # blocks substring false positives like + # ``some-provider/claude-mock-v1`` or ``other/kimi-pro``. + return False + # No provider prefix — accept bare ``claude-*`` and ``kimi-*`` ids + # so direct CLI configs (``claude-3-5-sonnet-20241022``, + # ``kimi-k2-instruct``) keep working. + return lowered.startswith("claude-") or lowered.startswith("kimi-") + + +def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None: + """Build the ``extra_body["reasoning"]`` fragment for the OpenAI client. + + Returns ``None`` for non-reasoning routes and for + ``max_thinking_tokens <= 0`` (operator kill switch). + """ + if not _is_reasoning_route(model) or max_thinking_tokens <= 0: + return None + return {"reasoning": {"max_tokens": max_thinking_tokens}} + + +class BaselineReasoningEmitter: + """Owns the reasoning block lifecycle for one streaming round. + + Two concerns live here, both driven by the same state machine: + + 1. **Wire events.** The AI SDK v6 wire format pairs every + ``reasoning-start`` with a matching ``reasoning-end`` and treats + reasoning / text / tool-use as distinct UI parts that must not + interleave. + 2. **Session persistence.** ``ChatMessage(role="reasoning")`` rows in + ``session.messages`` are what + ``convertChatSessionToUiMessages.ts`` folds into the assistant + bubble as ``{type: "reasoning"}`` UI parts on reload and on + ``useHydrateOnStreamEnd`` swaps. Without them the live-streamed + reasoning parts get overwritten by the hydrated (reasoning-less) + message list the moment the stream ends. Mirrors the SDK path's + ``acc.reasoning_response`` pattern so both routes render the same + way on reload. + + Pass ``session_messages`` to enable persistence; omit for pure + wire-emission (tests, scratch callers). On first reasoning delta a + fresh ``ChatMessage(role="reasoning")`` is appended and mutated + in-place as further deltas arrive; :meth:`close` drops the reference + but leaves the appended row intact. + + ``render_in_ui=False`` suppresses only the live wire events + (``StreamReasoning*``); the ``role='reasoning'`` persistence row is + still appended so ``convertChatSessionToUiMessages.ts`` can hydrate + the reasoning bubble on reload. The state machine advances + identically either way.
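+ + A minimal usage sketch for one streaming round (``stream`` and + ``events`` are illustrative names from the calling loop, not part of + this module):: + + emitter = BaselineReasoningEmitter(session.messages) + async for chunk in stream: + events.extend(emitter.on_delta(chunk.choices[0].delta)) + events.extend(emitter.close())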
+ """ + + def __init__( + self, + session_messages: list[ChatMessage] | None = None, + *, + coalesce_min_chars: int = _COALESCE_MIN_CHARS, + coalesce_max_interval_ms: float = _COALESCE_MAX_INTERVAL_MS, + render_in_ui: bool = True, + ) -> None: + self._block_id: str = str(uuid.uuid4()) + self._open: bool = False + self._session_messages = session_messages + self._current_row: ChatMessage | None = None + # Coalescing state — tests can disable (``=0``) for deterministic + # event assertions. + self._coalesce_min_chars = coalesce_min_chars + self._coalesce_max_interval_ms = coalesce_max_interval_ms + self._pending_delta: str = "" + self._last_flush_monotonic: float = 0.0 + self._render_in_ui = render_in_ui + + @property + def is_open(self) -> bool: + return self._open + + def on_delta(self, delta: ChoiceDelta) -> list[StreamBaseResponse]: + """Return events for the reasoning text carried by *delta*. + + Empty list when the chunk carries no reasoning payload, so this is + safe to call on every chunk without guarding at the call site. + + Persistence (when a session message list is attached) stays + per-delta so the DB row's content always equals the concatenation + of wire deltas at every chunk boundary, independent of the + coalescing window. Only the wire emission is batched. + """ + ext = OpenRouterDeltaExtension.from_delta(delta) + text = ext.visible_text() + if not text: + return [] + events: list[StreamBaseResponse] = [] + # First reasoning text in this block — emit Start + the first Delta + # atomically so the frontend Reasoning collapse renders immediately + # rather than waiting for the coalesce window to elapse. Subsequent + # chunks buffer into ``_pending_delta`` and only flush when the + # char/time thresholds trip. + # Sample the monotonic clock exactly once per chunk — at ~4,700 + # chunks per turn, folding the two calls into one cuts ~4,700 + # syscalls off the hot path without changing semantics. + now = time.monotonic() + if not self._open: + if self._render_in_ui: + events.append(StreamReasoningStart(id=self._block_id)) + events.append(StreamReasoningDelta(id=self._block_id, delta=text)) + self._open = True + self._last_flush_monotonic = now + if self._session_messages is not None: + self._current_row = ChatMessage(role="reasoning", content=text) + self._session_messages.append(self._current_row) + return events + + if self._current_row is not None: + self._current_row.content = (self._current_row.content or "") + text + + self._pending_delta += text + if self._should_flush_pending(now): + if self._render_in_ui: + events.append( + StreamReasoningDelta(id=self._block_id, delta=self._pending_delta) + ) + self._pending_delta = "" + self._last_flush_monotonic = now + return events + + def _should_flush_pending(self, now: float) -> bool: + """Return True when the accumulated delta should be emitted now. + + *now* is the monotonic timestamp sampled by the caller so the + clock is read at most once per chunk (the flush-timestamp update + reuses the same value). + """ + if not self._pending_delta: + return False + if len(self._pending_delta) >= self._coalesce_min_chars: + return True + elapsed_ms = (now - self._last_flush_monotonic) * 1000.0 + return elapsed_ms >= self._coalesce_max_interval_ms + + def close(self) -> list[StreamBaseResponse]: + """Emit ``StreamReasoningEnd`` for the open block (if any) and rotate. + + Idempotent — returns ``[]`` when no block is open. Drains any + still-buffered delta first so the frontend never loses tail text + from the coalesce window. 
The id rotation guarantees the next + reasoning block starts with a fresh id rather than reusing one + already closed on the wire. The persisted row is not removed — + it stays in ``session_messages`` as the durable record of what + was reasoned. + """ + if not self._open: + return [] + events: list[StreamBaseResponse] = [] + if self._render_in_ui: + if self._pending_delta: + events.append( + StreamReasoningDelta(id=self._block_id, delta=self._pending_delta) + ) + events.append(StreamReasoningEnd(id=self._block_id)) + self._pending_delta = "" + self._open = False + self._block_id = str(uuid.uuid4()) + self._current_row = None + return events diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py new file mode 100644 index 0000000000..1f5ca01845 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py @@ -0,0 +1,514 @@ +"""Tests for the baseline reasoning extension module. + +Covers the typed OpenRouter delta parser, the stateful emitter, and the +``extra_body`` builder. The emitter is tested against real +``ChoiceDelta`` pydantic instances so the ``model_extra`` plumbing the +parser relies on is exercised end-to-end. +""" + +from openai.types.chat.chat_completion_chunk import ChoiceDelta + +from backend.copilot.baseline.reasoning import ( + BaselineReasoningEmitter, + OpenRouterDeltaExtension, + ReasoningDetail, + _is_reasoning_route, + reasoning_extra_body, +) +from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, +) + + +def _delta(**extra) -> ChoiceDelta: + """Build a ChoiceDelta with the given extension fields on ``model_extra``.""" + return ChoiceDelta.model_validate({"role": "assistant", **extra}) + + +class TestReasoningDetail: + def test_visible_text_prefers_text(self): + d = ReasoningDetail(type="reasoning.text", text="hi", summary="ignored") + assert d.visible_text == "hi" + + def test_visible_text_falls_back_to_summary(self): + d = ReasoningDetail(type="reasoning.summary", summary="tldr") + assert d.visible_text == "tldr" + + def test_visible_text_empty_for_encrypted(self): + d = ReasoningDetail(type="reasoning.encrypted") + assert d.visible_text == "" + + def test_unknown_fields_are_ignored(self): + # OpenRouter may add new fields in future payloads — they shouldn't + # cause validation errors. + d = ReasoningDetail.model_validate( + {"type": "reasoning.future", "text": "x", "signature": "opaque"} + ) + assert d.text == "x" + + def test_visible_text_empty_for_unknown_type(self): + # Unknown types may carry provider metadata that must not render as + # user-visible reasoning — regardless of whether a text/summary is + # present. Only ``reasoning.text`` / ``reasoning.summary`` surface. + d = ReasoningDetail(type="reasoning.future", text="leaked metadata") + assert d.visible_text == "" + + def test_visible_text_surfaces_text_when_type_missing(self): + # Pre-``reasoning_details`` OpenRouter payloads omit ``type`` — treat + # them as text so we don't regress the legacy structured shape. 
+ d = ReasoningDetail(text="plain") + assert d.visible_text == "plain" + + +class TestOpenRouterDeltaExtension: + def test_from_delta_reads_model_extra(self): + delta = _delta(reasoning="step one") + ext = OpenRouterDeltaExtension.from_delta(delta) + assert ext.reasoning == "step one" + + def test_visible_text_legacy_string(self): + ext = OpenRouterDeltaExtension(reasoning="plain text") + assert ext.visible_text() == "plain text" + + def test_visible_text_deepseek_alias(self): + ext = OpenRouterDeltaExtension(reasoning_content="alt channel") + assert ext.visible_text() == "alt channel" + + def test_visible_text_structured_details_concat(self): + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.text", text="hello "), + ReasoningDetail(type="reasoning.text", text="world"), + ] + ) + assert ext.visible_text() == "hello world" + + def test_visible_text_skips_encrypted(self): + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.encrypted"), + ReasoningDetail(type="reasoning.text", text="visible"), + ] + ) + assert ext.visible_text() == "visible" + + def test_visible_text_empty_when_all_channels_blank(self): + ext = OpenRouterDeltaExtension() + assert ext.visible_text() == "" + + def test_empty_delta_produces_empty_extension(self): + ext = OpenRouterDeltaExtension.from_delta(_delta()) + assert ext.reasoning is None + assert ext.reasoning_content is None + assert ext.reasoning_details == [] + + def test_malformed_reasoning_payload_logged_and_swallowed(self, caplog): + # A malformed payload (e.g. reasoning_details shipped as a string + # rather than a list) must not abort the stream — log it and + # return an empty extension so valid text/tool events keep flowing. + # A plain mock is used here because ``from_delta`` only reads + # ``delta.model_extra`` — avoids reaching into pydantic internals + # (``__pydantic_extra__``) that could be renamed across versions. + from unittest.mock import MagicMock + + delta = MagicMock(spec=ChoiceDelta) + delta.model_extra = {"reasoning_details": "not a list"} + with caplog.at_level("WARNING"): + ext = OpenRouterDeltaExtension.from_delta(delta) + assert ext.reasoning_details == [] + assert ext.visible_text() == "" + assert any("malformed" in r.message.lower() for r in caplog.records) + + def test_unknown_typed_entry_with_text_is_not_surfaced(self): + # Regression: the legacy extractor emitted any entry with a + # ``text`` or ``summary`` field. The typed parser now filters on + # the recognised types so future provider metadata can't leak + # into the reasoning collapse. + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.future", text="provider metadata"), + ReasoningDetail(type="reasoning.text", text="real"), + ] + ) + assert ext.visible_text() == "real" + + +class TestIsReasoningRoute: + def test_anthropic_routes(self): + assert _is_reasoning_route("anthropic/claude-sonnet-4-6") + assert _is_reasoning_route("claude-3-5-sonnet-20241022") + assert _is_reasoning_route("anthropic.claude-3-5-sonnet") + assert _is_reasoning_route("ANTHROPIC/Claude-Opus") # case-insensitive + + def test_moonshot_kimi_routes(self): + # OpenRouter advertises the ``reasoning`` extension on Moonshot + # endpoints — both K2.6 (the new baseline default) and the + # reasoning-native kimi-k2-thinking variant. 
+ assert _is_reasoning_route("moonshotai/kimi-k2.6") + assert _is_reasoning_route("moonshotai/kimi-k2-thinking") + assert _is_reasoning_route("moonshotai/kimi-k2.5") + # Direct (non-OpenRouter) model ids also resolve via the ``kimi-`` + # prefix so a future bare ``kimi-k3`` id would still match. + assert _is_reasoning_route("kimi-k2-instruct") + # ``openrouter/``-prefixed kimi ids are also recognised via an + # explicit special case (OpenRouter is how we route to Moonshot + # today); any other provider prefix is rejected, as the + # false-positive tests below pin down. + assert _is_reasoning_route("openrouter/kimi-k2.6") + + def test_other_providers_rejected(self): + assert not _is_reasoning_route("openai/gpt-4o") + assert not _is_reasoning_route("google/gemini-2.5-pro") + assert not _is_reasoning_route("xai/grok-4") + assert not _is_reasoning_route("meta-llama/llama-3.3-70b-instruct") + assert not _is_reasoning_route("deepseek/deepseek-r1") + + def test_kimi_substring_false_positives_rejected(self): + # Regression: the previous implementation matched any model whose + # name contained the substring ``kimi`` — including unrelated model + # ids like ``hakimi``. The anchored match below rejects them. + assert not _is_reasoning_route("some-provider/hakimi-large") + assert not _is_reasoning_route("hakimi") + assert not _is_reasoning_route("akimi-7b") + + def test_claude_substring_false_positives_rejected(self): + # Regression (Sentry review on #12871): ``'claude' in lowered`` + # matched any substring — a custom + # ``someprovider/claude-mock-v1`` set via + # ``CHAT_FAST_STANDARD_MODEL`` would inherit the reasoning + # extra_body and take a 400 from its upstream. The anchored + # match requires either an ``anthropic`` / ``anthropic.`` / + # ``anthropic/`` prefix, or a bare ``claude-`` id with no + # provider prefix. + assert not _is_reasoning_route("someprovider/claude-mock-v1") + assert not _is_reasoning_route("custom/claude-like-model") + # Same principle for Kimi — a non-Moonshot provider prefix is + # rejected even when the model id starts with ``kimi-``. + assert not _is_reasoning_route("other/kimi-pro") + + +class TestReasoningExtraBody: + def test_anthropic_route_returns_fragment(self): + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == { + "reasoning": {"max_tokens": 4096} + } + + def test_direct_claude_model_id_still_matches(self): + assert reasoning_extra_body("claude-3-5-sonnet-20241022", 2048) == { + "reasoning": {"max_tokens": 2048} + } + + def test_kimi_routes_return_fragment(self): + # Kimi K2.6 ships the same OpenRouter ``reasoning`` extension as + # Anthropic, so the gate widened with this PR and the fragment + # must now materialise on Moonshot routes too. + assert reasoning_extra_body("moonshotai/kimi-k2.6", 8192) == { + "reasoning": {"max_tokens": 8192} + } + assert reasoning_extra_body("moonshotai/kimi-k2-thinking", 4096) == { + "reasoning": {"max_tokens": 4096} + } + + def test_non_reasoning_route_returns_none(self): + assert reasoning_extra_body("openai/gpt-4o", 4096) is None + assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None + assert reasoning_extra_body("xai/grok-4", 4096) is None + + def test_zero_max_tokens_kill_switch(self): + # Operator kill switch: ``max_thinking_tokens <= 0`` disables the + # ``reasoning`` extra_body fragment on ANY reasoning route (Anthropic + # or Kimi). Lets us silence reasoning without dropping the SDK + # path's budget.
+ assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None + assert reasoning_extra_body("moonshotai/kimi-k2.6", 0) is None + + +class TestBaselineReasoningEmitter: + def test_first_text_delta_emits_start_then_delta(self): + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="thinking")) + + assert len(events) == 2 + assert isinstance(events[0], StreamReasoningStart) + assert isinstance(events[1], StreamReasoningDelta) + assert events[0].id == events[1].id + assert events[1].delta == "thinking" + assert emitter.is_open is True + + def test_subsequent_deltas_reuse_block_id_without_new_start(self): + # Disable coalescing so each chunk flushes immediately — this test + # is about the Start/Delta/block-id state machine, not the coalesce + # window. Coalescing behaviour is covered below. + emitter = BaselineReasoningEmitter( + coalesce_min_chars=0, coalesce_max_interval_ms=0 + ) + first = emitter.on_delta(_delta(reasoning="a")) + second = emitter.on_delta(_delta(reasoning="b")) + + assert any(isinstance(e, StreamReasoningStart) for e in first) + assert all(not isinstance(e, StreamReasoningStart) for e in second) + assert len(second) == 1 + assert isinstance(second[0], StreamReasoningDelta) + assert first[0].id == second[0].id + + def test_empty_delta_emits_nothing(self): + emitter = BaselineReasoningEmitter() + assert emitter.on_delta(_delta(content="hello")) == [] + assert emitter.is_open is False + + def test_close_emits_end_and_rotates_id(self): + emitter = BaselineReasoningEmitter() + # Capture the block id from the wire event rather than reaching + # into emitter internals — the id on the emitted Start/Delta is + # what the frontend actually receives. + start_events = emitter.on_delta(_delta(reasoning="x")) + first_id = start_events[0].id + + events = emitter.close() + assert len(events) == 1 + assert isinstance(events[0], StreamReasoningEnd) + assert events[0].id == first_id + assert emitter.is_open is False + # Next reasoning uses a fresh id. + new_events = emitter.on_delta(_delta(reasoning="y")) + assert isinstance(new_events[0], StreamReasoningStart) + assert new_events[0].id != first_id + + def test_close_is_idempotent(self): + emitter = BaselineReasoningEmitter() + assert emitter.close() == [] + emitter.on_delta(_delta(reasoning="x")) + assert len(emitter.close()) == 1 + assert emitter.close() == [] + + def test_structured_details_round_trip(self): + emitter = BaselineReasoningEmitter() + events = emitter.on_delta( + _delta( + reasoning_details=[ + {"type": "reasoning.text", "text": "plan: "}, + {"type": "reasoning.summary", "summary": "do the thing"}, + ] + ) + ) + deltas = [e for e in events if isinstance(e, StreamReasoningDelta)] + assert len(deltas) == 1 + assert deltas[0].delta == "plan: do the thing" + + +class TestReasoningDeltaCoalescing: + """Coalescing batches fine-grained provider chunks into bigger wire + frames. OpenRouter's Kimi K2.6 emits ~4,700 reasoning-delta chunks + per turn vs ~28 for Sonnet; without batching, every chunk becomes one + Redis ``xadd`` + one SSE event + one React re-render of the + non-virtualised chat list, which paint-storms the browser. 
These + tests pin the batching contract: small chunks buffer until the + char-size or time threshold trips, large chunks still flush + immediately, and ``close()`` never drops tail text.""" + + def test_small_chunks_after_first_buffer_until_threshold(self): + # Generous time threshold so size alone controls flush timing. + emitter = BaselineReasoningEmitter( + coalesce_min_chars=32, coalesce_max_interval_ms=60_000 + ) + # First chunk always flushes immediately (so UI renders without + # waiting). + first = emitter.on_delta(_delta(reasoning="hi ")) + assert any(isinstance(e, StreamReasoningStart) for e in first) + assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1 + + # Subsequent small chunks buffer silently — 5 × 4 chars = 20 chars, + # still under the 32-char threshold. + for _ in range(5): + assert emitter.on_delta(_delta(reasoning="abcd")) == [] + + # Once the threshold is crossed, the accumulated buffer flushes + # as a single StreamReasoningDelta carrying every buffered chunk. + flush = emitter.on_delta(_delta(reasoning="efghijklmnop")) + assert len(flush) == 1 + assert isinstance(flush[0], StreamReasoningDelta) + assert flush[0].delta == "abcd" * 5 + "efghijklmnop" + + def test_time_based_flush_when_chars_stay_below_threshold(self, monkeypatch): + # Fake ``time.monotonic`` so we can drive the time-based branch + # deterministically without real sleeps. + from backend.copilot.baseline import reasoning as rmod + + fake_now = [0.0] + monkeypatch.setattr(rmod.time, "monotonic", lambda: fake_now[0]) + + emitter = BaselineReasoningEmitter( + coalesce_min_chars=1000, coalesce_max_interval_ms=40 + ) + # t=0: first chunk flushes immediately. + first = emitter.on_delta(_delta(reasoning="a")) + assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1 + + # t=10 ms: still under 40 ms → buffer. + fake_now[0] = 0.010 + assert emitter.on_delta(_delta(reasoning="b")) == [] + + # t=50 ms since last flush → time threshold trips, flush fires. 
+ fake_now[0] = 0.060 + flushed = emitter.on_delta(_delta(reasoning="c")) + assert len(flushed) == 1 + assert isinstance(flushed[0], StreamReasoningDelta) + assert flushed[0].delta == "bc" + + def test_close_flushes_tail_buffer_before_end(self): + emitter = BaselineReasoningEmitter( + coalesce_min_chars=1000, coalesce_max_interval_ms=60_000 + ) + emitter.on_delta(_delta(reasoning="first")) # flushes (first chunk) + emitter.on_delta(_delta(reasoning=" middle ")) # buffered + emitter.on_delta(_delta(reasoning="tail")) # buffered + + events = emitter.close() + assert len(events) == 2 + assert isinstance(events[0], StreamReasoningDelta) + assert events[0].delta == " middle tail" + assert isinstance(events[1], StreamReasoningEnd) + + def test_coalesce_disabled_flushes_every_chunk(self): + emitter = BaselineReasoningEmitter( + coalesce_min_chars=0, coalesce_max_interval_ms=0 + ) + first = emitter.on_delta(_delta(reasoning="a")) + second = emitter.on_delta(_delta(reasoning="b")) + assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1 + assert sum(isinstance(e, StreamReasoningDelta) for e in second) == 1 + + def test_persistence_stays_per_delta_even_when_wire_coalesces(self): + """DB row content must track every chunk so a crash mid-turn + persists the full reasoning-so-far, even if the coalesce window + never flushed those chunks to the wire.""" + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter( + session, + coalesce_min_chars=1000, + coalesce_max_interval_ms=60_000, + ) + emitter.on_delta(_delta(reasoning="first ")) + emitter.on_delta(_delta(reasoning="chunk ")) + emitter.on_delta(_delta(reasoning="three")) + # No close; verify the persisted row already has everything. + assert len(session) == 1 + assert session[0].content == "first chunk three" + + +class TestReasoningPersistence: + """The persistence contract: without ``role="reasoning"`` rows in + session.messages, useHydrateOnStreamEnd overwrites the live-streamed + reasoning parts and the Reasoning collapse vanishes. 
Every delta + must be reflected in the persisted row the moment it's emitted.""" + + def test_session_row_appended_on_first_delta(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + assert session == [] + emitter.on_delta(_delta(reasoning="hi")) + assert len(session) == 1 + assert session[0].role == "reasoning" + assert session[0].content == "hi" + + def test_subsequent_deltas_mutate_same_row(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="part one ")) + emitter.on_delta(_delta(reasoning="part two")) + + assert len(session) == 1 + assert session[0].content == "part one part two" + + def test_close_keeps_row_in_session(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="thought")) + emitter.close() + + assert len(session) == 1 + assert session[0].content == "thought" + + def test_second_reasoning_block_appends_new_row(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="first")) + emitter.close() + emitter.on_delta(_delta(reasoning="second")) + + assert len(session) == 2 + assert [m.content for m in session] == ["first", "second"] + + def test_no_session_means_no_persistence(self): + """Emitter without attached session list emits wire events only.""" + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="pure wire")) + assert len(events) == 2 # start + delta, no crash + # Nothing else to assert — just proves None session is supported. + + +class TestBaselineReasoningEmitterRenderFlag: + """``render_in_ui=False`` must silence ``StreamReasoning*`` wire events + AND drop persistence of ``role="reasoning"`` rows — the operator hides + the collapse on both the live wire and on reload. Persistence is tied + to the wire events because the frontend's hydration path unconditionally + re-renders persisted reasoning rows; keeping them would make the flag a + no-op post-reload. These tests pin the contract in both directions so + future refactors can't flip only one half.""" + + def test_render_off_suppresses_start_and_delta(self): + emitter = BaselineReasoningEmitter(render_in_ui=False) + events = emitter.on_delta(_delta(reasoning="hidden")) + # No wire events, but state advanced (is_open == True) so close() + # below has something to rotate. + assert events == [] + assert emitter.is_open is True + + def test_render_off_suppresses_close_end(self): + emitter = BaselineReasoningEmitter(render_in_ui=False) + emitter.on_delta(_delta(reasoning="hidden")) + events = emitter.close() + assert events == [] + assert emitter.is_open is False + + def test_render_off_still_persists(self): + """Persistence is decoupled from the render flag — session + transcript always keeps the ``role="reasoning"`` row so audit + and ``--resume``-equivalent replay never lose thinking text. 
+ The frontend gates rendering separately.""" + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session, render_in_ui=False) + + emitter.on_delta(_delta(reasoning="part one ")) + emitter.on_delta(_delta(reasoning="part two")) + emitter.close() + + assert len(session) == 1 + assert session[0].role == "reasoning" + assert session[0].content == "part one part two" + + def test_render_off_rotates_block_id_between_sessions(self): + """Even with wire events silenced the block id must rotate on close, + otherwise a hypothetical mid-session flip would reuse a stale id.""" + emitter = BaselineReasoningEmitter(render_in_ui=False) + emitter.on_delta(_delta(reasoning="first")) + first_block_id = emitter._block_id + emitter.close() + emitter.on_delta(_delta(reasoning="second")) + assert emitter._block_id != first_block_id + + def test_render_on_is_default(self): + """Defaulting to True preserves backward compat — existing callers + that don't pass the kwarg keep emitting wire events as before.""" + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="hello")) + assert len(events) == 2 + assert isinstance(events[0], StreamReasoningStart) + assert isinstance(events[1], StreamReasoningDelta) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 4c6ad04d60..a8866026ce 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -15,16 +15,26 @@ import re import shutil import tempfile import uuid -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import partial from typing import TYPE_CHECKING, Any, cast import orjson from langfuse import propagate_attributes +from openai.types import CompletionUsage from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam +from openai.types.completion_usage import PromptTokensDetails from opentelemetry import trace as otel_trace +from backend.copilot.baseline.reasoning import ( + BaselineReasoningEmitter, + reasoning_extra_body, +) +from backend.copilot.builder_context import ( + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, +) from backend.copilot.config import CopilotLlmModel, CopilotMode from backend.copilot.context import get_workspace_manager, set_execution_context from backend.copilot.graphiti.config import is_enabled_for_user @@ -35,10 +45,11 @@ from backend.copilot.model import ( maybe_append_user_message, upsert_chat_session, ) +from backend.copilot.model_router import resolve_model +from backend.copilot.moonshot import is_moonshot_model from backend.copilot.pending_message_helpers import ( combine_pending_with_current, drain_pending_safe, - pending_texts_from, persist_pending_as_user_rows, persist_session_safe, ) @@ -46,7 +57,7 @@ from backend.copilot.pending_messages import ( drain_pending_messages, format_pending_as_user_message, ) -from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement +from backend.copilot.prompting import SHARED_TOOL_NOTES, get_graphiti_supplement from backend.copilot.response_model import ( StreamBaseResponse, StreamError, @@ -70,9 +81,10 @@ from backend.copilot.service import ( inject_user_context, strip_user_context_tags, ) +from backend.copilot.session_cleanup import prune_orphan_tool_calls from 
backend.copilot.thinking_stripper import ThinkingStripper as _ThinkingStripper from backend.copilot.token_tracking import persist_and_record_usage -from backend.copilot.tools import execute_tool, get_available_tools +from backend.copilot.tools import ToolGroup, execute_tool, get_available_tools from backend.copilot.tracking import track_user_message from backend.copilot.transcript import ( STOP_REASON_END_TURN, @@ -109,8 +121,57 @@ logger = logging.getLogger(__name__) # Set to hold background tasks to prevent garbage collection _background_tasks: set[asyncio.Task[Any]] = set() -# Maximum number of tool-call rounds before forcing a text response. -_MAX_TOOL_ROUNDS = 30 +# Hint appended on the last tool round so the model wraps up with a summary +# instead of issuing another tool call that gets cut off cold. The shared +# ``tool_call_loop`` drops ``tools`` on the last iteration (see util/tool_call_loop.py), +# so the model is forced to produce text and always finishes naturally. +_LAST_ITERATION_HINT = ( + "You have reached the tool-call budget for this turn. Do not call any " + "more tools — produce a final text response summarizing what you did, " + "what remains, and how the user can continue the work in the next turn." +) + +# Fallback surfaced when the tool-round budget is exhausted *and* the forced- +# text last round left the user with zero visible response. +_BUDGET_EXHAUSTED_FALLBACK_TEXT = ( + "Reached the tool-call budget for this turn. " + "Send a follow-up message to continue from here." +) + + +def _budget_exhausted_notice_text(terminal_round_text: str) -> str | None: + """Return the fallback notice when a budget-exhausted turn produced no + visible text, or ``None`` when the model already summarised itself. + + ``terminal_round_text`` is the text added by the *final* round only — + earlier-round chatter shouldn't mask a silent final round. + """ + if terminal_round_text.strip(): + return None + return _BUDGET_EXHAUSTED_FALLBACK_TEXT + + +def _build_budget_exhausted_fallback_events( + terminal_round_text: str, +) -> tuple[list[StreamBaseResponse], str]: + """Build the fallback stream events surfaced when a budget-exhausted + turn left the terminal round with no visible text. + + Returns ``(events, text_to_append)``. Empty list + empty string when + no fallback is needed. Split out of the async generator so it's unit- + testable without the surrounding streaming machinery. + """ + notice = _budget_exhausted_notice_text(terminal_round_text) + if notice is None: + return [], "" + block_id = str(uuid.uuid4()) + events: list[StreamBaseResponse] = [ + StreamTextStart(id=block_id), + StreamTextDelta(id=block_id, delta=notice), + StreamTextEnd(id=block_id), + ] + return events, notice + # Max seconds to wait for transcript upload in the finally block before # letting it continue as a background task (tracked in _background_tasks). @@ -126,6 +187,78 @@ _MAX_INLINE_IMAGE_BYTES = 20 * 1024 * 1024 # Matches characters unsafe for filenames. _UNSAFE_FILENAME = re.compile(r"[^\w.\-]") +# OpenRouter-specific extra_body flag that embeds the real generation cost +# into the final usage chunk. Module-level constant so we don't reallocate +# an identical dict on every streaming call. +_OPENROUTER_INCLUDE_USAGE_COST = {"usage": {"include": True}} + + +def _extract_usage_cost(usage: CompletionUsage) -> float | None: + """Return the provider-reported USD cost on a streaming usage chunk. 
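# Illustrative usage sketch for the two budget-exhaustion helpers above (not
# part of the patch; assumes those definitions are in scope, and the sample
# summary string is made up). A silent terminal round gets the canned notice
# as a complete text block; a round that already produced text gets nothing.
events, text = _build_budget_exhausted_fallback_events("")
assert text == _BUDGET_EXHAUSTED_FALLBACK_TEXT
assert [type(e).__name__ for e in events] == [
    "StreamTextStart",
    "StreamTextDelta",
    "StreamTextEnd",
]

events, text = _build_budget_exhausted_fallback_events("Wrapped up; summary posted.")
assert events == [] and text == ""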
+ + OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible usage + object when the request body includes ``usage: {"include": True}``. + The OpenAI SDK's typed ``CompletionUsage`` does not declare it, so we + read it off ``model_extra`` (the pydantic v2 container for extras) to + keep the access fully typed — no ``getattr``. + + Returns ``None`` when the field is absent, explicitly null, + non-numeric, non-finite, or negative. Invalid values (including + present-but-null) are logged here — they indicate a provider bug + worth chasing; plain absences are silent so the caller can dedupe + the "missing cost" warning per stream. + """ + extras = usage.model_extra or {} + if "cost" not in extras: + return None + raw = extras["cost"] + if raw is None: + logger.error("[Baseline] usage.cost is present but null") + return None + try: + val = float(raw) + except (TypeError, ValueError): + logger.error("[Baseline] usage.cost is not numeric: %r", raw) + return None + if not math.isfinite(val) or val < 0: + logger.error("[Baseline] usage.cost is non-finite or negative: %r", val) + return None + return val + + +def _extract_cache_creation_tokens(ptd: PromptTokensDetails) -> int: + """Return cache-write token count from an OpenAI-compatible + ``PromptTokensDetails``, handling provider-specific field names and + SDK-version shape differences. + + Two shapes we care about: + + - **OpenRouter** (our primary baseline provider) streams the cache-write + count as ``cache_write_tokens``. Newer ``openai-python`` versions + declare this as a typed attribute on ``PromptTokensDetails``; older + versions expose it only in ``model_extra``. Verified empirically: + cold-cache request returns ``cache_write_tokens`` > 0, warm-cache + request returns ``cached_tokens`` > 0 and ``cache_write_tokens`` = 0. + - **Direct Anthropic API** uses ``cache_creation_input_tokens`` — + never a typed attribute on the OpenAI SDK, always lives in + ``model_extra``. + + Lookup order: typed attr → ``model_extra`` (OpenRouter) → ``model_extra`` + (Anthropic-native). ``getattr`` handles both the typed-attr case + (newer SDK) and the no-such-attr case (older SDK) — we can't only use + ``model_extra`` because when the field is typed it's filtered out of + ``model_extra``, leaving us at 0 on the modern happy path. + """ + typed_val = getattr(ptd, "cache_write_tokens", None) + if typed_val: + return int(typed_val) + extras = ptd.model_extra or {} + return int( + extras.get("cache_write_tokens") + or extras.get("cache_creation_input_tokens") + or 0 + ) + async def _prepare_baseline_attachments( file_ids: list[str], @@ -236,17 +369,17 @@ def _filter_tools_by_permissions( ] -def _resolve_baseline_model(tier: CopilotLlmModel | None) -> str: +async def _resolve_baseline_model( + tier: CopilotLlmModel | None, user_id: str | None +) -> str: """Pick the model for the baseline path based on the per-request tier. - The baseline (fast) and SDK (extended thinking) paths now share the - same tier-based model resolution — only the *path* differs between - "fast" and "extended_thinking". ``'advanced'`` → Opus; - ``'standard'`` / ``None`` → the config default (Sonnet). + Delegates to :func:`copilot.model_router.resolve_model` so the + ``(fast, tier)`` cell is LD-overridable per user. ``None`` tier + maps to ``"standard"``. 
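# Minimal sketch of where the provider-specific usage fields land (not part of
# the patch; the token counts and cost value are made up). Both extraction
# helpers above read these fields from ``model_extra`` when the SDK does not
# declare them as typed attributes.
from openai.types import CompletionUsage

usage = CompletionUsage.model_validate(
    {
        "prompt_tokens": 1200,
        "completion_tokens": 300,
        "total_tokens": 1500,
        # OpenRouter-only extra, present when the request body carries
        # ``usage: {"include": true}``; untyped, so it lands in model_extra.
        "cost": 0.0123,
        "prompt_tokens_details": {
            # cache read (typed field) vs cache write (provider-specific extra).
            "cached_tokens": 800,
            "cache_creation_input_tokens": 0,
        },
    }
)
assert (usage.model_extra or {}).get("cost") == 0.0123
assert usage.prompt_tokens_details is not None
assert usage.prompt_tokens_details.cached_tokens == 800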
""" - from backend.copilot.service import resolve_chat_model - - return resolve_chat_model(tier) + tier_name = "advanced" if tier == "advanced" else "standard" + return await resolve_model("fast", tier_name, user_id, config=config) @dataclass @@ -258,22 +391,224 @@ class _BaselineStreamState: """ model: str = "" - pending_events: list[StreamBaseResponse] = field(default_factory=list) + # Live delivery channel drained concurrently by ``stream_chat_completion_baseline`` + # so reasoning / text / tool events reach the SSE wire **during** the upstream + # LLM stream, not after ``_baseline_llm_caller`` returns. Before this was a + # ``list`` drained per ``tool_call_loop`` iteration, so any model with + # extended thinking (Anthropic via OpenRouter, Moonshot, future reasoning + # routes) froze the UI for the entire duration of each LLM round before + # flushing the backlog in one burst. The queue is single-producer (the + # streaming loop) / single-consumer (the outer async-gen yield loop); + # ``None`` is the close sentinel. + pending_events: asyncio.Queue[StreamBaseResponse | None] = field( + default_factory=asyncio.Queue + ) + # Mirror of every event put on ``pending_events`` — kept for unit tests that + # inspect post-hoc what was emitted. Not consumed by production code. + emitted_events: list[StreamBaseResponse] = field(default_factory=list) assistant_text: str = "" text_block_id: str = field(default_factory=lambda: str(uuid.uuid4())) text_started: bool = False + reasoning_emitter: BaselineReasoningEmitter = field(init=False) turn_prompt_tokens: int = 0 turn_completion_tokens: int = 0 turn_cache_read_tokens: int = 0 turn_cache_creation_tokens: int = 0 cost_usd: float | None = None + # Tracks whether we've already warned about a missing `cost` field in + # the usage chunk this stream, so non-OpenRouter providers don't + # generate one warning per streaming call. + cost_missing_logged: bool = False thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper) + # MUTATE in place only — ``__post_init__`` hands this list reference to + # ``BaselineReasoningEmitter`` so reasoning rows can be appended as + # deltas stream in. Reassigning (``state.session_messages = [...]``) + # would silently detach the emitter from the new list. session_messages: list[ChatMessage] = field(default_factory=list) # Tracks how much of ``assistant_text`` has already been flushed to # ``session.messages`` via mid-loop pending drains, so the ``finally`` # block only appends the *new* assistant text (avoiding duplication of # round-1 text when round-1 entries were cleared from session_messages). _flushed_assistant_text_len: int = 0 + # Memoised system-message dict with cache_control applied. The system + # prompt is static within a session, so we build it once on the first + # LLM round and reuse the same dict on subsequent rounds — avoiding + # an O(N) dict-copy of the growing ``messages`` list on every tool-call + # iteration. ``None`` means "not yet computed" (or the first message + # wasn't a system role, so no marking applies). + cached_system_message: dict[str, Any] | None = None + + def __post_init__(self) -> None: + # Wire the reasoning emitter to ``session_messages`` so it can + # append ``role="reasoning"`` rows as reasoning streams in — the + # frontend's ``convertChatSessionToUiMessages`` relies on these + # rows to render the Reasoning collapse after the AI SDK's + # stream-end hydrate swaps in the DB-backed message list. 
+ # ``render_in_ui`` is sourced from ``config.render_reasoning_in_ui`` + # so the operator can silence the reasoning collapse globally + # without dropping the persisted audit trail. + self.reasoning_emitter = BaselineReasoningEmitter( + self.session_messages, + render_in_ui=config.render_reasoning_in_ui, + ) + + +def _emit(state: "_BaselineStreamState", event: StreamBaseResponse) -> None: + """Queue *event* for the live SSE wire AND mirror into ``emitted_events``. + + Single helper so every streaming producer (LLM stream loop, tool executor, + conversation updater) posts to the same single-consumer queue. The mirror + list is read-only from production code — it exists so unit tests can assert + on the full sequence emitted during one call. + """ + state.pending_events.put_nowait(event) + state.emitted_events.append(event) + + +def _emit_all( + state: "_BaselineStreamState", events: Iterable[StreamBaseResponse] +) -> None: + """Queue *events* in order — convenience for emitter batches.""" + for event in events: + _emit(state, event) + + +def _is_anthropic_model(model: str) -> bool: + """Return True if *model* routes to Anthropic (native or via OpenRouter). + + Examples that return True: + - ``anthropic/claude-sonnet-4-6`` (OpenRouter route) + - ``claude-3-5-sonnet-20241022`` (direct Anthropic API) + - ``anthropic.claude-3-5-sonnet`` (Bedrock-style) + + False for ``openai/gpt-4o``, ``google/gemini-2.5-pro``, ``xai/grok-4`` + etc. Moonshot is False here too even though Moonshot's + Anthropic-compat endpoint honours ``cache_control`` — use + :func:`_supports_prompt_cache_markers` for the cache-gating decision, + which also allows Moonshot routes. This function stays scoped to + "genuinely Anthropic" so callers that need the stricter check (e.g. + ``anthropic-beta`` header emission) keep their existing semantics. + """ + lowered = model.lower() + return "claude" in lowered or lowered.startswith("anthropic") + + +def _supports_prompt_cache_markers(model: str) -> bool: + """Return True when *model* accepts Anthropic-style ``cache_control``. + + Superset of :func:`_is_anthropic_model` — also allows Moonshot + (``moonshotai/*``), whose OpenRouter Anthropic-compat endpoint + honours the marker and empirically lifts cache hit rate on + continuation turns from near-zero (Moonshot's own automatic prefix + cache, which drifts readily) to the 60-95% Anthropic ballpark. + + OpenAI / Grok / Gemini still 400 on ``cache_control``, so this + function returns False for those providers — add new vendors here + only after verifying their endpoint accepts the field. + """ + return _is_anthropic_model(model) or is_moonshot_model(model) + + +def _fresh_ephemeral_cache_control() -> dict[str, str]: + """Return a FRESH ephemeral ``cache_control`` dict each call. + + The ``ttl`` is sourced from :attr:`ChatConfig.baseline_prompt_cache_ttl` + (default ``1h``) so the static prefix stays warm across many users' + requests in the same workspace cache. Anthropic caches are keyed + per-workspace, so every copilot user reading the same system prompt + hits the same cached entry. + + Using a shared module-level dict would let any downstream mutation + (e.g. the OpenAI SDK normalising fields in-place) poison every future + request's marker. Construction is O(1) so the safety margin is free. + """ + return {"type": "ephemeral", "ttl": config.baseline_prompt_cache_ttl} + + +def _fresh_anthropic_caching_headers() -> dict[str, str]: + """Return a FRESH ``extra_headers`` dict requesting the Anthropic + prompt-caching beta. 
+ + Same reasoning as :func:`_fresh_ephemeral_cache_control`: never hand a + shared module-level dict to third-party SDKs. OpenRouter auto-forwards + cache_control for Anthropic routes without this header, but passing it + makes the intent unambiguous on-wire and is a no-op for non-Anthropic + providers (unknown headers are dropped). + """ + return {"anthropic-beta": "prompt-caching-2024-07-31"} + + +def _mark_tools_with_cache_control( + tools: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *tools* with ``cache_control`` on the last entry. + + Marking the last tool is a cache breakpoint that covers the whole tool + schema block as a cacheable prefix segment. Extracted from + :func:`_mark_system_message_with_cache_control` so callers can precompute + the marked tool list once per session — the tool set is static within a + request and the ~43 dict-copies would otherwise run on every LLM round + in the tool-call loop. + + **Only call this for Anthropic model routes.** Non-Anthropic providers + (OpenAI, Grok, Gemini) reject the unknown ``cache_control`` field with + a 400 schema validation error. Gate via :func:`_is_anthropic_model`. + """ + cached: list[dict[str, Any]] = [dict(t) for t in tools] + if cached: + cached[-1] = { + **cached[-1], + "cache_control": _fresh_ephemeral_cache_control(), + } + return cached + + +def _build_cached_system_message( + system_message: Mapping[str, Any], +) -> dict[str, Any]: + """Return a copy of *system_message* with ``cache_control`` applied. + + Anthropic's cache uses prefix-match with up to 4 explicit breakpoints. + Combined with the last-tool marker this gives two cache segments — the + system block alone, and system+all-tools — so requests that share only + the system prefix still get a partial cache hit. + + The system message is rebuilt via spread (``{**original, ...}``) so any + unknown fields the caller set (e.g. ``name``) survive the transformation. + Non-Anthropic models silently ignore the markers. + + Returns the original dict (shallow-copied) unchanged when the content + shape is unsupported (missing / non-string / empty) — callers should + splice it into the message list as-is in that case. + """ + sys_copy = dict(system_message) + sys_content = sys_copy.get("content") + if isinstance(sys_content, str) and sys_content: + sys_copy["content"] = [ + { + "type": "text", + "text": sys_content, + "cache_control": _fresh_ephemeral_cache_control(), + } + ] + return sys_copy + + +def _mark_system_message_with_cache_control( + messages: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *messages* with ``cache_control`` on the system block. + + Thin wrapper around :func:`_build_cached_system_message` that preserves + the original list shape. Prefer the memoised path in + ``_baseline_llm_caller`` (which builds the cached system dict once per + session) for hot-loop callers; this function is retained for call sites + outside the tool-call loop where per-call copying is acceptable. + """ + cached_messages: list[dict[str, Any]] = [dict(m) for m in messages] + if cached_messages and cached_messages[0].get("role") == "system": + cached_messages[0] = _build_cached_system_message(cached_messages[0]) + return cached_messages async def _baseline_llm_caller( @@ -286,32 +621,74 @@ async def _baseline_llm_caller( Extracted from ``stream_chat_completion_baseline`` for readability. 
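# Shape sketch for the markers the caching helpers above attach (not an
# excerpt from the patch; the prompt text and tool names are illustrative,
# and the "1h" ttl mirrors the documented config default). The system content
# string is wrapped into a single text block carrying ``cache_control``, and
# only the last tool definition gets the breakpoint.
marked_system = {
    "role": "system",
    "content": [
        {
            "type": "text",
            "text": "You are the AutoGPT copilot ...",
            "cache_control": {"type": "ephemeral", "ttl": "1h"},
        }
    ],
}
marked_tools = [
    {"type": "function", "function": {"name": "get_agent_building_guide"}},  # unmarked
    {
        "type": "function",
        "function": {"name": "create_agent"},
        "cache_control": {"type": "ephemeral", "ttl": "1h"},  # breakpoint on last entry
    },
]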
""" - state.pending_events.append(StreamStartStep()) + _emit(state, StreamStartStep()) # Fresh thinking-strip state per round so a malformed unclosed # block in one LLM call cannot silently drop content in the next. state.thinking_stripper = _ThinkingStripper() round_text = "" - response = None # initialized before try so finally block can access it try: client = _get_openai_client() - typed_messages = cast(list[ChatCompletionMessageParam], messages) - if tools: - typed_tools = cast(list[ChatCompletionToolParam], tools) - response = await client.chat.completions.create( - model=state.model, - messages=typed_messages, - tools=typed_tools, - stream=True, - stream_options={"include_usage": True}, + # Cache markers are accepted by Anthropic AND Moonshot (via OR's + # Anthropic-compat endpoint). OpenAI/Grok/Gemini 400 on the + # unknown ``cache_control`` field — tools were precomputed in + # stream_chat_completion_baseline via _mark_tools_with_cache_control + # with the same gate, so on unsupported routes tools ship + # unmarked too. + # + # The ``anthropic-beta`` header is only emitted for genuinely + # Anthropic routes (see :func:`_is_anthropic_model`) — Moonshot + # doesn't need the beta header; sending it is a no-op but we + # keep the check strict for clarity. + # + # `extra_body` `usage.include=true` asks OpenRouter to embed the + # real generation cost into the final usage chunk — required by + # the cost-based rate limiter in routes.py. Separate from the + # caching headers, always sent. + supports_cache = _supports_prompt_cache_markers(state.model) + if supports_cache: + # Build the cached system dict once per session and splice it in + # on each round. The full ``messages`` list grows with every + # tool call, so copying the entire list just to mutate index 0 + # scales with conversation length (sentry flagged this); this + # splice touches only list slots, not message contents. 
+ if ( + state.cached_system_message is None + and messages + and messages[0].get("role") == "system" + ): + state.cached_system_message = _build_cached_system_message(messages[0]) + if state.cached_system_message is not None and messages: + final_messages = [state.cached_system_message, *messages[1:]] + else: + final_messages = messages + extra_headers = ( + _fresh_anthropic_caching_headers() + if _is_anthropic_model(state.model) + else None ) else: - response = await client.chat.completions.create( - model=state.model, - messages=typed_messages, - stream=True, - stream_options={"include_usage": True}, - ) + final_messages = messages + extra_headers = None + typed_messages = cast(list[ChatCompletionMessageParam], final_messages) + extra_body: dict[str, Any] = dict(_OPENROUTER_INCLUDE_USAGE_COST) + reasoning_param = reasoning_extra_body( + state.model, config.claude_agent_max_thinking_tokens + ) + if reasoning_param: + extra_body.update(reasoning_param) + create_kwargs: dict[str, Any] = { + "model": state.model, + "messages": typed_messages, + "stream": True, + "stream_options": {"include_usage": True}, + "extra_body": extra_body, + } + if extra_headers: + create_kwargs["extra_headers"] = extra_headers + if tools: + create_kwargs["tools"] = cast(list[ChatCompletionToolParam], list(tools)) + response = await client.chat.completions.create(**create_kwargs) tool_calls_by_index: dict[int, dict[str, str]] = {} # Iterate under an inner try/finally so early exits (cancel, tool-call @@ -323,37 +700,62 @@ async def _baseline_llm_caller( if chunk.usage: state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 state.turn_completion_tokens += chunk.usage.completion_tokens or 0 - # Extract cache token details when available (OpenAI / - # OpenRouter include these in prompt_tokens_details). - ptd = getattr(chunk.usage, "prompt_tokens_details", None) + ptd = chunk.usage.prompt_tokens_details if ptd: - state.turn_cache_read_tokens += ( - getattr(ptd, "cached_tokens", 0) or 0 - ) - # cache_creation_input_tokens is reported by some providers - # (e.g. Anthropic native) but not standard OpenAI streaming. + state.turn_cache_read_tokens += ptd.cached_tokens or 0 state.turn_cache_creation_tokens += ( - getattr(ptd, "cache_creation_input_tokens", 0) or 0 + _extract_cache_creation_tokens(ptd) ) + cost = _extract_usage_cost(chunk.usage) + if cost is not None: + state.cost_usd = (state.cost_usd or 0.0) + cost + elif ( + "cost" not in (chunk.usage.model_extra or {}) + and not state.cost_missing_logged + ): + # Field absent (non-OpenRouter route, or OpenRouter + # misconfigured) — warn once per stream so error + # monitoring picks up persistent misses without + # flooding. Invalid values already logged inside + # _extract_usage_cost, so no duplicate warning here. + logger.warning( + "[Baseline] usage chunk missing cost (model=%s, " + "prompt=%s, completion=%s) — rate-limit will " + "skip this call", + state.model, + chunk.usage.prompt_tokens, + chunk.usage.completion_tokens, + ) + state.cost_missing_logged = True delta = chunk.choices[0].delta if chunk.choices else None if not delta: continue + _emit_all(state, state.reasoning_emitter.on_delta(delta)) + if delta.content: + # Text and reasoning must not interleave on the wire — the + # AI SDK maps distinct start/end pairs to distinct UI + # parts. Close any open reasoning block before emitting + # the first text delta of this run. 
+ _emit_all(state, state.reasoning_emitter.close()) emit = state.thinking_stripper.process(delta.content) if emit: if not state.text_started: - state.pending_events.append( - StreamTextStart(id=state.text_block_id) - ) + _emit(state, StreamTextStart(id=state.text_block_id)) state.text_started = True round_text += emit - state.pending_events.append( - StreamTextDelta(id=state.text_block_id, delta=emit) + _emit( + state, + StreamTextDelta(id=state.text_block_id, delta=emit), ) if delta.tool_calls: + # Same rule as the text branch: close any open reasoning + # block before a tool_use starts so the AI SDK treats + # reasoning and tool-use as distinct parts. + _emit_all(state, state.reasoning_emitter.close()) for tc in delta.tool_calls: idx = tc.index if idx not in tool_calls_by_index: @@ -378,42 +780,31 @@ async def _baseline_llm_caller( except Exception: pass + finally: + # Close open blocks on both normal and exception paths so the + # frontend always sees matched start/end pairs. An exception mid + # ``async for chunk in response`` would otherwise leave reasoning + # and/or text unterminated and only ``StreamFinishStep`` emitted — + # the Reasoning / Text collapses would never finalise. + _emit_all(state, state.reasoning_emitter.close()) # Flush any buffered text held back by the thinking stripper. tail = state.thinking_stripper.flush() if tail: if not state.text_started: - state.pending_events.append(StreamTextStart(id=state.text_block_id)) + _emit(state, StreamTextStart(id=state.text_block_id)) state.text_started = True round_text += tail - state.pending_events.append( - StreamTextDelta(id=state.text_block_id, delta=tail) - ) - # Close text block + _emit(state, StreamTextDelta(id=state.text_block_id, delta=tail)) if state.text_started: - state.pending_events.append(StreamTextEnd(id=state.text_block_id)) + _emit(state, StreamTextEnd(id=state.text_block_id)) state.text_started = False state.text_block_id = str(uuid.uuid4()) - finally: - # Extract OpenRouter cost from response headers (in finally so we - # capture cost even when the stream errors mid-way — we already paid). - # Accumulate across multi-round tool-calling turns. - try: - # Access undocumented _response attribute — same pattern as - # extract_openrouter_cost() in blocks/llm.py. - cost_header = response._response.headers.get("x-total-cost") # type: ignore[attr-defined] - if cost_header: - cost = float(cost_header) - if math.isfinite(cost) and cost >= 0: - state.cost_usd = (state.cost_usd or 0.0) + cost - except (AttributeError, ValueError): - pass - # Always persist partial text so the session history stays consistent, # even when the stream is interrupted by an exception. state.assistant_text += round_text # Always emit StreamFinishStep to match the StreamStartStep, # even if an exception occurred during streaming. 
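# Wire-order sketch distilled from the invariants described above (not an
# excerpt from the patch): one round with reasoning followed by text always
# produces matched start/end pairs, with reasoning closed before text opens
# and StreamFinishStep pairing the StreamStartStep even on error paths.
expected_round_order = [
    "StreamStartStep",
    "StreamReasoningStart",
    "StreamReasoningDelta",
    "StreamReasoningEnd",
    "StreamTextStart",
    "StreamTextDelta",
    "StreamTextEnd",
    "StreamFinishStep",
]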
- state.pending_events.append(StreamFinishStep()) + _emit(state, StreamFinishStep()) # Convert to shared format llm_tool_calls = [ @@ -455,13 +846,14 @@ async def _baseline_tool_executor( except orjson.JSONDecodeError as parse_err: parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}" logger.warning("[Baseline] %s", parse_error) - state.pending_events.append( + _emit( + state, StreamToolOutputAvailable( toolCallId=tool_call_id, toolName=tool_name, output=parse_error, success=False, - ) + ), ) return ToolCallResult( tool_call_id=tool_call_id, @@ -470,17 +862,32 @@ async def _baseline_tool_executor( is_error=True, ) - state.pending_events.append( - StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name) + _emit( + state, + StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name), ) - state.pending_events.append( + _emit( + state, StreamToolInputAvailable( toolCallId=tool_call_id, toolName=tool_name, input=tool_args, - ) + ), ) + # Announce the tool call to the session so in-turn guards like + # ``require_guide_read`` can see it *right now*, before the tool + # actually runs. Without this, the tool_call row lives only in + # ``state.session_messages`` until the ``finally`` block flushes it + # into ``session.messages`` at turn end — so a second tool in the + # same turn (e.g. ``create_agent`` after ``get_agent_building_guide``) + # scans a stale ``session.messages`` and the guard re-fires despite + # the guide having been called. The announce-set is cleared at turn + # end; we deliberately don't touch ``session.messages`` here to avoid + # duplicating the assistant row that ``_baseline_conversation_updater`` + # will append at round end. + session.announce_inflight_tool_call(tool_name) + try: result: StreamToolOutputAvailable = await execute_tool( tool_name=tool_name, @@ -489,7 +896,7 @@ async def _baseline_tool_executor( session=session, tool_call_id=tool_call_id, ) - state.pending_events.append(result) + _emit(state, result) tool_output = ( result.output if isinstance(result.output, str) else str(result.output) ) @@ -506,13 +913,14 @@ async def _baseline_tool_executor( error_output, exc_info=True, ) - state.pending_events.append( + _emit( + state, StreamToolOutputAvailable( toolCallId=tool_call_id, toolName=tool_name, output=error_output, success=False, - ) + ), ) return ToolCallResult( tool_call_id=tool_call_id, @@ -948,6 +1356,12 @@ async def stream_chat_completion_baseline( f"Session {session_id} not found. Please create a new session first." ) + # Drop orphan tool_use + trailing stop-marker rows left by a previous + # Stop mid-tool-call so the new turn starts from a well-formed message list. + prune_orphan_tool_calls( + session.messages, log_prefix=f"[Baseline] [{session_id[:12]}]" + ) + # Strip any user-injected tags on every turn. # Only the server-injected prefix on the first message is trusted. if message: @@ -982,7 +1396,6 @@ async def stream_chat_completion_baseline( len(drained_at_start_pending), session_id, ) - drained_at_start_content = pending_texts_from(drained_at_start_pending) # Chronological combine: pending typed BEFORE this /stream # request's arrival go ahead of ``message``; race-path follow-ups # typed AFTER (queued while /stream was still processing) go @@ -1016,7 +1429,7 @@ async def stream_chat_completion_baseline( # Select model based on the per-request tier toggle (standard / advanced). 
# The path (fast vs extended_thinking) is already decided — we're in the # baseline (fast) path; ``mode`` is accepted for logging parity only. - active_model = _resolve_baseline_model(model) + active_model = await _resolve_baseline_model(model, user_id) # --- E2B sandbox setup (feature parity with SDK path) --- e2b_sandbox = None @@ -1107,7 +1520,18 @@ async def stream_chat_completion_baseline( graphiti_enabled = await is_enabled_for_user(user_id) graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" - system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement + # Append the builder-session block (graph id+name + full building guide) + # AFTER the shared supplements so the system prompt is byte-identical + # across turns of the same builder session — Claude's prompt cache keeps + # the ~20KB guide warm for the whole session. Empty string for + # non-builder sessions keeps the cross-user cache hot. + builder_session_suffix = await build_builder_system_prompt_suffix(session) + system_prompt = ( + base_system_prompt + + SHARED_TOOL_NOTES + + graphiti_supplement + + builder_session_suffix + ) # Warm context: pre-load relevant facts from Graphiti on first turn. # Use the pre-drain count so pending messages drained at turn start @@ -1191,6 +1615,26 @@ async def stream_chat_completion_baseline( # Do NOT append warm_ctx to user_message_for_transcript — it would # persist stale temporal context into the transcript for future turns. + # Inject the per-turn ```` prefix when the session is + # bound to a graph via ``metadata.builder_graph_id``. Runs on every + # user turn (not just the first) so the LLM always sees the live graph + # snapshot — if the user edits the graph between turns, the next turn + # carries the updated nodes/links. Only version + nodes + links here; + # the static guide + graph id live in the system prompt via + # ``build_builder_system_prompt_suffix`` (session-stable, prompt-cached). + # Prepended AFTER any // blocks + # — same trust tier as those server-injected prefixes. Not persisted to + # the transcript: the snapshot is stale-by-definition after the turn ends. + if is_user_message and session.metadata.builder_graph_id: + builder_block = await build_builder_context_turn_prefix(session, user_id) + if builder_block: + for msg in reversed(openai_messages): + if msg["role"] == "user": + existing = msg.get("content", "") + if isinstance(existing, str): + msg["content"] = builder_block + existing + break + # Append user message to transcript. # Always append when the message is present and is from the user, # even on duplicate-suppressed retries (is_new_message=False). @@ -1251,12 +1695,28 @@ async def stream_chat_completion_baseline( openai_messages[i]["content"] = text break - tools = get_available_tools() + disabled_tool_groups: list[ToolGroup] = [] + if not graphiti_enabled: + disabled_tool_groups.append("graphiti") + tools = get_available_tools(disabled_groups=disabled_tool_groups) # --- Permission filtering --- if permissions is not None: tools = _filter_tools_by_permissions(tools, permissions) + # Pre-mark cache_control on the last tool schema once per session. The + # tool set is static within a request, so doing this here (instead of in + # _baseline_llm_caller) avoids re-copying ~43 tool dicts on every LLM + # round of the tool-call loop. 
+ # + # Applies to Anthropic AND Moonshot routes — OpenAI/Grok/Gemini 400 + # on the unknown ``cache_control`` field inside tool definitions, so + # the gate stays narrow (see :func:`_supports_prompt_cache_markers`). + if _supports_prompt_cache_markers(active_model): + tools = cast( + list[ChatCompletionToolParam], _mark_tools_with_cache_control(tools) + ) + # Propagate execution context so tool handlers can read session-level flags. set_execution_context( user_id, @@ -1316,179 +1776,250 @@ async def stream_chat_completion_baseline( state=state, ) - try: - loop_result = None - async for loop_result in tool_call_loop( - messages=openai_messages, - tools=tools, - llm_call=_bound_llm_caller, - execute_tool=_bound_tool_executor, - update_conversation=_bound_conversation_updater, - max_iterations=_MAX_TOOL_ROUNDS, - ): - # Drain buffered events after each iteration (real-time streaming) - for evt in state.pending_events: - yield evt - state.pending_events.clear() + # Run the tool-call loop concurrently with the event consumer so + # ``StreamReasoning*`` / ``StreamText*`` deltas emitted inside + # ``_baseline_llm_caller`` reach the SSE wire DURING the upstream LLM + # stream instead of only at iteration boundaries. Any reasoning route + # that streams for several minutes per round (extended thinking on + # Anthropic / Moonshot / future providers) would otherwise freeze the + # UI for the whole window before flushing the backlog in one burst. + loop_result_holder: list[Any] = [None] + loop_task: asyncio.Task[None] | None = None + # Length of ``state.assistant_text`` at the end of the last non-final + # yield — used as an anchor by the budget-exhausted fallback to check + # whether the *terminal* round produced any visible text, not the whole + # turn. Without this, earlier-round chatter would suppress a fallback + # that should fire. + text_len_before_final_round: list[int] = [0] - # Inject any messages the user queued while the turn was - # running. ``tool_call_loop`` mutates ``openai_messages`` - # in-place, so appending here means the model sees the new - # messages on its next LLM call. - # - # IMPORTANT: skip when the loop has already finished (no - # more LLM calls are coming). ``tool_call_loop`` yields - # a final ``ToolCallLoopResult`` on both paths: - # - natural finish: ``finished_naturally=True`` - # - hit max_iterations: ``finished_naturally=False`` - # and ``iterations >= max_iterations`` - # In either case the loop is about to return on the next - # ``async for`` step, so draining here would silently - # lose the message (the user sees 202 but the model never - # reads the text). Those messages stay in the buffer and - # get picked up at the start of the next turn. - is_final_yield = ( - loop_result.finished_naturally - or loop_result.iterations >= _MAX_TOOL_ROUNDS - ) - if is_final_yield: - continue - try: - pending = await drain_pending_messages(session_id) - except Exception: - logger.warning( - "[Baseline] mid-loop drain_pending_messages failed for session %s", - session_id, - exc_info=True, - ) - pending = [] - if pending: - # Flush any buffered assistant/tool messages from completed - # rounds into session.messages BEFORE appending the pending - # user message. ``_baseline_conversation_updater`` only - # records assistant+tool rounds into ``state.session_messages`` - # — they are normally batch-flushed in the finally block. 
- # Without this in-order flush, the mid-loop pending user - # message lands before the preceding round's assistant/tool - # entries, producing chronologically-wrong session.messages - # on persist (user interposed between an assistant tool_call - # and its tool-result), which breaks OpenAI tool-call ordering - # invariants on the next turn's replay. + async def _run_tool_call_loop() -> None: + # Read/write the current session via ``_session_holder`` so this + # closure doesn't need to ``nonlocal session`` — pyright can't narrow + # the outer ``session: ChatSession | None`` through a nested scope, + # but the holder is typed non-optional after the preflight guard + # above. + try: + max_tool_rounds = config.agent_max_turns + async for loop_result in tool_call_loop( + messages=openai_messages, + tools=tools, + llm_call=_bound_llm_caller, + execute_tool=_bound_tool_executor, + update_conversation=_bound_conversation_updater, + max_iterations=max_tool_rounds, + last_iteration_message=_LAST_ITERATION_HINT, + ): + loop_result_holder[0] = loop_result + # Inject any messages the user queued while the turn was + # running. ``tool_call_loop`` mutates ``openai_messages`` + # in-place, so appending here means the model sees the new + # messages on its next LLM call. # - # Also persist any assistant text from text-only rounds (rounds - # with no tool calls, which ``_baseline_conversation_updater`` - # does NOT record in session_messages). If we only update - # ``_flushed_assistant_text_len`` without persisting the text, - # that text is silently lost: the finally block only appends - # assistant_text[_flushed_assistant_text_len:], so text generated - # before this drain never reaches session.messages. - recorded_text = "".join( - m.content or "" - for m in state.session_messages - if m.role == "assistant" + # IMPORTANT: skip when the loop has already finished (no + # more LLM calls are coming). ``tool_call_loop`` yields + # a final ``ToolCallLoopResult`` on both paths: + # - natural finish: ``finished_naturally=True`` + # - hit max_iterations: ``finished_naturally=False`` + # and ``iterations >= max_iterations`` + # In either case the loop is about to return on the next + # ``async for`` step, so draining here would silently + # lose the message (the user sees 202 but the model never + # reads the text). Those messages stay in the buffer and + # get picked up at the start of the next turn. + is_final_yield = ( + loop_result.finished_naturally + or loop_result.iterations >= max_tool_rounds ) - unflushed_text = state.assistant_text[ - state._flushed_assistant_text_len : - ] - text_only_text = ( - unflushed_text[len(recorded_text) :] - if unflushed_text.startswith(recorded_text) - else unflushed_text - ) - if text_only_text.strip(): - session.messages.append( - ChatMessage(role="assistant", content=text_only_text) + if is_final_yield: + continue + # Non-final yield: the next round may be the last one, so + # record where ``assistant_text`` ends now. If that next + # round hits the budget without adding any text, the outer + # fallback uses this anchor to detect a silent finish. 
+ text_len_before_final_round[0] = len(state.assistant_text) + try: + pending = await drain_pending_messages(session_id) + except Exception: + logger.warning( + "[Baseline] mid-loop drain_pending_messages failed for " + "session %s", + session_id, + exc_info=True, ) - for _buffered in state.session_messages: - session.messages.append(_buffered) - state.session_messages.clear() - # Record how much assistant_text has been covered by the - # structured entries just flushed, so the finally block's - # final-text dedup doesn't re-append rounds already persisted. - state._flushed_assistant_text_len = len(state.assistant_text) + pending = [] + if pending: + # Flush any buffered assistant/tool messages from completed + # rounds into session.messages BEFORE appending the pending + # user message. ``_baseline_conversation_updater`` only + # records assistant+tool rounds into ``state.session_messages`` + # — they are normally batch-flushed in the finally block. + # Without this in-order flush, the mid-loop pending user + # message lands before the preceding round's assistant/tool + # entries, producing chronologically-wrong session.messages + # on persist (user interposed between an assistant tool_call + # and its tool-result), which breaks OpenAI tool-call ordering + # invariants on the next turn's replay. + # + # Also persist any assistant text from text-only rounds (rounds + # with no tool calls, which ``_baseline_conversation_updater`` + # does NOT record in session_messages). If we only update + # ``_flushed_assistant_text_len`` without persisting the text, + # that text is silently lost: the finally block only appends + # assistant_text[_flushed_assistant_text_len:], so text generated + # before this drain never reaches session.messages. + recorded_text = "".join( + m.content or "" + for m in state.session_messages + if m.role == "assistant" + ) + unflushed_text = state.assistant_text[ + state._flushed_assistant_text_len : + ] + text_only_text = ( + unflushed_text[len(recorded_text) :] + if unflushed_text.startswith(recorded_text) + else unflushed_text + ) + current_session = _session_holder[0] + if text_only_text.strip(): + current_session.messages.append( + ChatMessage(role="assistant", content=text_only_text) + ) + for _buffered in state.session_messages: + current_session.messages.append(_buffered) + state.session_messages.clear() + # Record how much assistant_text has been covered by the + # structured entries just flushed, so the finally block's + # final-text dedup doesn't re-append rounds already persisted. + state._flushed_assistant_text_len = len(state.assistant_text) - # Persist the assistant/tool flush BEFORE the pending append - # so a later pending-persist failure can roll back the - # pending rows without also discarding LLM output. - session = await persist_session_safe(session, "[Baseline]") - # ``upsert_chat_session`` may return a *new* ``ChatSession`` - # instance (e.g. when a concurrent title update has written a - # newer title to Redis, it returns ``session.model_copy``). - # Keep ``_session_holder`` in sync so subsequent tool rounds - # executed via ``_bound_tool_executor`` see the fresh session - # — any tool-side mutations on the stale object would be - # discarded when the new one is persisted in the ``finally``. - _session_holder[0] = session + # Persist the assistant/tool flush BEFORE the pending append + # so a later pending-persist failure can roll back the + # pending rows without also discarding LLM output. 
+ current_session = await persist_session_safe( + current_session, "[Baseline]" + ) + # ``upsert_chat_session`` may return a *new* ``ChatSession`` + # instance (e.g. when a concurrent title update has written a + # newer title to Redis, it returns ``session.model_copy``). + # Keep ``_session_holder`` in sync so subsequent tool rounds + # executed via ``_bound_tool_executor`` see the fresh session + # — any tool-side mutations on the stale object would be + # discarded when the new one is persisted in the ``finally``. + _session_holder[0] = current_session - # ``format_pending_as_user_message`` embeds file attachments - # and context URL/page content into the content string so - # the in-session transcript is a faithful copy of what the - # model actually saw. We also mirror each push into - # ``openai_messages`` so the model's next LLM round sees it. - # - # Pre-compute the formatted dicts once so both the openai - # messages append and the content_of lookup inside the - # shared helper use the same string — and so ``on_rollback`` - # can trim ``openai_messages`` to the recorded anchor. - formatted_by_pm = { - id(pm): format_pending_as_user_message(pm) for pm in pending - } - _openai_anchor = len(openai_messages) - for pm in pending: - openai_messages.append(formatted_by_pm[id(pm)]) + # ``format_pending_as_user_message`` embeds file attachments + # and context URL/page content into the content string so + # the in-session transcript is a faithful copy of what the + # model actually saw. We also mirror each push into + # ``openai_messages`` so the model's next LLM round sees it. + # + # Pre-compute the formatted dicts once so both the openai + # messages append and the content_of lookup inside the + # shared helper use the same string — and so ``on_rollback`` + # can trim ``openai_messages`` to the recorded anchor. + formatted_by_pm = { + id(pm): format_pending_as_user_message(pm) for pm in pending + } + _openai_anchor = len(openai_messages) + for pm in pending: + openai_messages.append(formatted_by_pm[id(pm)]) - def _trim_openai_on_rollback(_session_anchor: int) -> None: - del openai_messages[_openai_anchor:] + def _trim_openai_on_rollback(_session_anchor: int) -> None: + del openai_messages[_openai_anchor:] - await persist_pending_as_user_rows( - session, - transcript_builder, - pending, - log_prefix="[Baseline]", - content_of=lambda pm: formatted_by_pm[id(pm)]["content"], - on_rollback=_trim_openai_on_rollback, + await persist_pending_as_user_rows( + current_session, + transcript_builder, + pending, + log_prefix="[Baseline]", + content_of=lambda pm: formatted_by_pm[id(pm)]["content"], + on_rollback=_trim_openai_on_rollback, + ) + finally: + # Always post the sentinel so the outer consumer exits — even if + # ``tool_call_loop`` raised. ``_baseline_llm_caller``'s own + # finally block has already pushed ``StreamReasoningEnd`` / + # ``StreamTextEnd`` / ``StreamFinishStep`` at this point, so the + # sentinel only terminates the consumer; it does not suppress + # any still-unflushed events. + state.pending_events.put_nowait(None) + + loop_task = asyncio.create_task(_run_tool_call_loop()) + try: + while True: + evt = await state.pending_events.get() + if evt is None: + break + yield evt + # Sentinel received — surface any exception the inner task hit. + await loop_task + loop_result = loop_result_holder[0] + # Budget was reached when iterations hit the configured cap. 
This + # covers both exit paths out of ``tool_call_loop``: + # - ``finished_naturally=True``: the last iteration ran with + # ``tools=[]`` and the model returned text (may be empty) + # - ``finished_naturally=False``: a non-compliant model still + # emitted tool calls despite the empty tool list, so the loop + # fell through the ``while`` guard + # Either way, we check the terminal round's text contribution — an + # empty one means the user got no explanation and we need to emit + # the fallback notice. + budget_reached = bool( + loop_result and loop_result.iterations >= config.agent_max_turns + ) + if budget_reached: + if loop_result and not loop_result.finished_naturally: + logger.warning( + "[Baseline] Hit %d-round tool budget without natural finish; " + "ending turn gracefully", + loop_result.iterations, ) - - if loop_result and not loop_result.finished_naturally: - limit_msg = ( - f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds " - "without a final response." + terminal_round_text = state.assistant_text[text_len_before_final_round[0] :] + fallback_events, fallback_text = _build_budget_exhausted_fallback_events( + terminal_round_text ) - logger.error("[Baseline] %s", limit_msg) - yield StreamError( - errorText=limit_msg, - code="baseline_tool_round_limit", - ) - + for evt in fallback_events: + yield evt + state.assistant_text += fallback_text except Exception as e: _stream_error = True error_msg = str(e) or type(e).__name__ logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True) - # Close any open text block. The llm_caller's finally block - # already appended StreamFinishStep to pending_events, so we must - # insert StreamTextEnd *before* StreamFinishStep to preserve the - # protocol ordering: - # StreamStartStep -> StreamTextStart -> ...deltas... -> - # StreamTextEnd -> StreamFinishStep - # Appending (or yielding directly) would place it after - # StreamFinishStep, violating the protocol. - if state.text_started: - # Find the last StreamFinishStep and insert before it. - insert_pos = len(state.pending_events) - for i in range(len(state.pending_events) - 1, -1, -1): - if isinstance(state.pending_events[i], StreamFinishStep): - insert_pos = i - break - state.pending_events.insert( - insert_pos, StreamTextEnd(id=state.text_block_id) - ) - # Drain pending events in correct order - for evt in state.pending_events: - yield evt - state.pending_events.clear() + # Drain any queued tail events (reasoning/text close + finish step) + # that ``_baseline_llm_caller``'s finally block pushed before the + # sentinel arrived — without this the frontend would be missing the + # matching end / finish parts for the partial round. + while not state.pending_events.empty(): + evt = state.pending_events.get_nowait() + if evt is not None: + yield evt yield StreamError(errorText=error_msg, code="baseline_error") # Still persist whatever we got finally: + # Cancel the inner task if we're unwinding early (client disconnect, + # unexpected error in the consumer) so it doesn't keep streaming + # tokens into a dead queue. + if loop_task is not None and not loop_task.done(): + loop_task.cancel() + try: + await loop_task + except (asyncio.CancelledError, Exception): + pass + # Re-sync the outer ``session`` binding in case the inner task + # reassigned it via a mid-loop ``persist_session_safe`` call. 
+ session = _session_holder[0] + + # In-flight tool-call announcements are only meaningful for the + # current turn; clear at the top of the outer finally so the next + # turn starts with a clean scratch buffer even if one of the + # awaited cleanup steps below (usage persistence, session upsert, + # transcript upload) raises. The buffer is a process-local scratch + # set — if we leak it into the next turn the guide-read guard would + # observe a phantom in-flight call and skip its gate, so this must + # run unconditionally. + session.clear_inflight_tool_calls() + # Pending messages are drained atomically at turn start and # between tool rounds, so there's nothing to clear in finally. # Any message pushed after the final drain window stays in the @@ -1644,6 +2175,8 @@ async def stream_chat_completion_baseline( prompt_tokens=billed_prompt, completion_tokens=state.turn_completion_tokens, total_tokens=billed_prompt + state.turn_completion_tokens, + cache_read_tokens=state.turn_cache_read_tokens, + cache_creation_tokens=state.turn_cache_creation_tokens, ) yield StreamFinish() diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index a0e55d843f..1f3cfedb2b 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -10,11 +10,31 @@ import pytest from openai.types.chat import ChatCompletionToolParam from backend.copilot.baseline.service import ( + _BUDGET_EXHAUSTED_FALLBACK_TEXT, _baseline_conversation_updater, + _baseline_llm_caller, _BaselineStreamState, + _budget_exhausted_notice_text, + _build_budget_exhausted_fallback_events, + _build_cached_system_message, _compress_session_messages, + _extract_cache_creation_tokens, + _fresh_anthropic_caching_headers, + _fresh_ephemeral_cache_control, + _is_anthropic_model, + _mark_system_message_with_cache_control, + _mark_tools_with_cache_control, + _supports_prompt_cache_markers, ) from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, +) from backend.copilot.transcript_builder import TranscriptBuilder from backend.util.prompt import CompressResult from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult @@ -23,7 +43,10 @@ from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallRe class TestBaselineStreamState: def test_defaults(self): state = _BaselineStreamState() - assert state.pending_events == [] + # ``pending_events`` is an asyncio.Queue now (live SSE channel). + # The durable inspection view is ``emitted_events``. + assert state.pending_events.empty() + assert state.emitted_events == [] assert state.assistant_text == "" assert state.text_started is False assert state.turn_prompt_tokens == 0 @@ -574,37 +597,87 @@ class TestPrepareBaselineAttachments: assert blocks == [] +_COST_MISSING = object() + + +def _make_usage_chunk( + *, + prompt_tokens: int = 0, + completion_tokens: int = 0, + cost: float | str | None | object = _COST_MISSING, + cached_tokens: int | None = None, + cache_creation_input_tokens: int | None = None, +): + """Build a mock streaming chunk carrying usage (and optionally cost). 
+ + Provider-specific fields (``cost`` on usage, ``cache_creation_input_tokens`` + on prompt_tokens_details) are set on ``model_extra`` because that's where + the baseline helper reads them from (typed ``CompletionUsage.model_extra`` + rather than ``getattr``). Pass ``cost=None`` to emit an explicit-null cost + key; omit ``cost`` entirely to leave the key absent. + """ + chunk = MagicMock() + chunk.choices = [] + chunk.usage = MagicMock() + chunk.usage.prompt_tokens = prompt_tokens + chunk.usage.completion_tokens = completion_tokens + usage_extras: dict[str, float | str | None] = {} + if cost is not _COST_MISSING: + usage_extras["cost"] = cost # type: ignore[assignment] + chunk.usage.model_extra = usage_extras + + if cached_tokens is not None or cache_creation_input_tokens is not None: + # Build a real ``PromptTokensDetails`` so ``getattr(ptd, + # "cache_write_tokens", None)`` returns ``None`` on this SDK version + # (rather than a truthy MagicMock attribute) and the extraction + # helper's typed-attr vs model_extra fallback resolves correctly. + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": cached_tokens or 0}) + if cache_creation_input_tokens is not None: + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = cache_creation_input_tokens + chunk.usage.prompt_tokens_details = ptd + else: + chunk.usage.prompt_tokens_details = None + + return chunk + + +def _make_stream_mock(*chunks): + """Build an async streaming response mock that yields *chunks* in order.""" + stream = MagicMock() + stream.close = AsyncMock() + + async def aiter(): + for c in chunks: + yield c + + stream.__aiter__ = lambda self: aiter() + return stream + + class TestBaselineCostExtraction: - """Tests for x-total-cost header extraction in _baseline_llm_caller.""" + """Tests for ``usage.cost`` extraction in ``_baseline_llm_caller``. + + Cost is read from the OpenRouter ``usage.cost`` field on the final + streaming chunk when the request body includes ``usage: {include: true}`` + (handled by the baseline service via ``extra_body``). 
+ """ @pytest.mark.asyncio - async def test_cost_usd_extracted_from_response_header(self): - """state.cost_usd is set from x-total-cost header when present.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + async def test_cost_usd_extracted_from_usage_chunk(self): + """state.cost_usd is set from chunk.usage.cost when present.""" state = _BaselineStreamState(model="gpt-4o-mini") - - # Build a mock raw httpx response with the cost header - mock_raw_response = MagicMock() - mock_raw_response.headers = {"x-total-cost": "0.0123"} - - # Build a mock async streaming response that yields no chunks but has - # a _response attribute pointing to the mock httpx response - mock_stream_response = MagicMock() - mock_stream_response._response = mock_raw_response - - async def empty_aiter(): - return - yield # make it an async generator - - mock_stream_response.__aiter__ = lambda self: empty_aiter() + chunk = _make_usage_chunk( + prompt_tokens=1000, completion_tokens=200, cost=0.0123 + ) mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( - return_value=mock_stream_response + return_value=_make_stream_mock(chunk) ) with patch( @@ -622,29 +695,14 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_cost_usd_accumulates_across_calls(self): """cost_usd accumulates when _baseline_llm_caller is called multiple times.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - state = _BaselineStreamState(model="gpt-4o-mini") - def make_stream_mock(cost: str) -> MagicMock: - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": cost} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - async def empty_aiter(): - return - yield - - mock_stream.__aiter__ = lambda self: empty_aiter() - return mock_stream - mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( - side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")] + side_effect=[ + _make_stream_mock(_make_usage_chunk(prompt_tokens=500, cost=0.01)), + _make_stream_mock(_make_usage_chunk(prompt_tokens=600, cost=0.02)), + ] ) with patch( @@ -665,28 +723,64 @@ class TestBaselineCostExtraction: assert state.cost_usd == pytest.approx(0.03) @pytest.mark.asyncio - async def test_no_cost_when_header_absent(self): - """state.cost_usd remains None when response has no x-total-cost header.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + async def test_cost_usd_accepts_string_value(self): + """OpenRouter may emit cost as a string — it should still parse.""" state = _BaselineStreamState(model="gpt-4o-mini") - - mock_raw = MagicMock() - mock_raw.headers = {} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - async def empty_aiter(): - return - yield - - mock_stream.__aiter__ = lambda self: empty_aiter() + chunk = _make_usage_chunk(prompt_tokens=10, cost="0.005") mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd == pytest.approx(0.005) + + @pytest.mark.asyncio + async def test_cost_usd_none_when_usage_cost_missing(self): + 
"""state.cost_usd stays None when the usage chunk lacks a cost field.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=1000, completion_tokens=500) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + # Token accumulators are still populated so the caller can log them. + assert state.turn_prompt_tokens == 1000 + assert state.turn_completion_tokens == 500 + + @pytest.mark.asyncio + async def test_invalid_cost_string_leaves_cost_none(self): + """A non-numeric cost value is rejected without raising.""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost="not-a-number") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -701,28 +795,73 @@ class TestBaselineCostExtraction: assert state.cost_usd is None @pytest.mark.asyncio - async def test_cost_extracted_even_when_stream_raises(self): - """cost_usd is captured in the finally block even when streaming fails.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + async def test_negative_cost_is_ignored(self): + """Guard against negative cost values (shouldn't happen but be safe).""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost=-0.01) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) ) + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + + @pytest.mark.asyncio + async def test_explicit_null_cost_is_logged_and_ignored(self, caplog): + """`{"cost": null}` is rejected and logged (not silently dropped).""" + state = _BaselineStreamState(model="openrouter/auto") + chunk = _make_usage_chunk(prompt_tokens=10, cost=None) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + caplog.at_level("ERROR", logger="backend.copilot.baseline.service"), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + assert any( + "usage.cost is present but null" in rec.message for rec in caplog.records + ) + + @pytest.mark.asyncio + async def test_cost_not_captured_when_stream_raises_mid_chunk(self): + """If the stream aborts before emitting the usage chunk there is no cost.""" state = _BaselineStreamState(model="gpt-4o-mini") - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.005"} - mock_stream = MagicMock() - mock_stream._response = mock_raw + stream = MagicMock() + stream.close = AsyncMock() async def failing_aiter(): raise RuntimeError("stream error") yield # make it an async generator - mock_stream.__aiter__ = lambda self: failing_aiter() + 
stream.__aiter__ = lambda self: failing_aiter() mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock(return_value=stream) with ( patch( @@ -737,16 +876,12 @@ class TestBaselineCostExtraction: state=state, ) - assert state.cost_usd == pytest.approx(0.005) + # Stream aborted before yielding the usage chunk — cost stays None. + assert state.cost_usd is None @pytest.mark.asyncio async def test_no_cost_when_api_call_raises_before_stream(self): - """finally block is safe when response is None (API call failed before yielding).""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + """The helper is safe when the create() call itself raises.""" state = _BaselineStreamState(model="gpt-4o-mini") mock_client = MagicMock() @@ -767,84 +902,23 @@ class TestBaselineCostExtraction: state=state, ) - # response was never assigned so cost extraction must not raise - assert state.cost_usd is None - - @pytest.mark.asyncio - async def test_no_cost_when_header_missing(self): - """cost_usd remains None when x-total-cost is absent.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 500 - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) - - with patch( - "backend.copilot.baseline.service._get_openai_client", - return_value=mock_client, - ): - await _baseline_llm_caller( - messages=[{"role": "user", "content": "hi"}], - tools=[], - state=state, - ) - assert state.cost_usd is None @pytest.mark.asyncio async def test_cache_tokens_extracted_from_usage_details(self): """cache tokens are extracted from prompt_tokens_details.cached_tokens.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=800, ) - state = _BaselineStreamState(model="openai/gpt-4o") - - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.01"} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - # Create a chunk with prompt_tokens_details - mock_ptd = MagicMock() - mock_ptd.cached_tokens = 800 - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 200 - mock_chunk.usage.prompt_tokens_details = mock_ptd - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -861,37 +935,20 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def 
test_cache_creation_tokens_extracted_from_usage_details(self): - """cache_creation_tokens are extracted from prompt_tokens_details.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + """cache_creation_input_tokens is extracted from prompt_tokens_details.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=0, + cache_creation_input_tokens=500, ) - state = _BaselineStreamState(model="openai/gpt-4o") - - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.01"} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_ptd = MagicMock() - mock_ptd.cached_tokens = 0 - mock_ptd.cache_creation_input_tokens = 500 - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 200 - mock_chunk.usage.prompt_tokens_details = mock_ptd - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -908,37 +965,17 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_token_accumulators_track_across_multiple_calls(self): """Token accumulators grow correctly across multiple _baseline_llm_caller calls.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - def make_stream(prompt_tokens: int, completion_tokens: int): - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = prompt_tokens - mock_chunk.usage.completion_tokens = completion_tokens - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - return mock_stream - mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( side_effect=[ - make_stream(1000, 200), - make_stream(1100, 300), + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1000, completion_tokens=200) + ), + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1100, completion_tokens=300) + ), ] ) @@ -957,45 +994,33 @@ class TestBaselineCostExtraction: state=state, ) - # No x-total-cost header and empty pricing table -- cost_usd remains None + # No usage.cost on either chunk → cost stays None, tokens still accumulate. assert state.cost_usd is None - # Accumulators hold all tokens across both turns assert state.turn_prompt_tokens == 2100 assert state.turn_completion_tokens == 500 + @pytest.mark.parametrize( + "tools", + [ + pytest.param([], id="no_tools"), + pytest.param([_make_tool("search")], id="with_tools"), + ], + ) @pytest.mark.asyncio - async def test_cost_usd_remains_none_when_header_missing(self): - """cost_usd stays None when x-total-cost header is absent. + async def test_baseline_requests_usage_include_extra_body( + self, tools: list[ChatCompletionToolParam] + ): + """The baseline call must pass extra_body={'usage': {'include': True}}. 
- Token counts are still tracked; persist_and_record_usage handles - the None cost by falling back to tracking_type='tokens'. + This guards the contract with OpenRouter that triggers inclusion of + the authoritative cost on the final usage chunk. Without it the + rate-limit counter stays at zero. Exercise both the no-tools and + tool-calling branches so a regression in either path trips the test. """ - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 500 - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - + state = _BaselineStreamState(model="gpt-4o-mini") + create_mock = AsyncMock(return_value=_make_stream_mock()) mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = create_mock with patch( "backend.copilot.baseline.service._get_openai_client", @@ -1003,13 +1028,15 @@ class TestBaselineCostExtraction: ): await _baseline_llm_caller( messages=[{"role": "user", "content": "hi"}], - tools=[], + tools=tools, state=state, ) - assert state.cost_usd is None - assert state.turn_prompt_tokens == 1000 - assert state.turn_completion_tokens == 500 + create_mock.assert_awaited_once() + await_args = create_mock.await_args + assert await_args is not None + assert await_args.kwargs["extra_body"] == {"usage": {"include": True}} + assert await_args.kwargs["stream_options"] == {"include_usage": True} class TestMidLoopPendingFlushOrdering: @@ -1211,3 +1238,908 @@ class TestMidLoopPendingFlushOrdering: assert assistant_msgs[1].tool_calls is None # Crucially: only 2 assistant messages, not 3 (no duplicate) assert len(assistant_msgs) == 2 + + +class TestBuilderContextSplit: + """Cross-helper composition: the guide must land in the system prompt via + ``build_builder_system_prompt_suffix`` and NOT in the per-turn user prefix + via ``build_builder_context_turn_prefix``. + + The baseline service composes these two blocks on each turn, so a drift + here (guide leaking into both, or missing from both) would kill Claude's + prompt-cache hit rate for builder sessions. 
+ """ + + @pytest.mark.asyncio + async def test_guide_lives_in_system_prompt_not_user_message(self): + from backend.copilot.builder_context import ( + BUILDER_CONTEXT_TAG, + BUILDER_SESSION_TAG, + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, + ) + from backend.copilot.model import ChatSession + + session = MagicMock(spec=ChatSession) + session.session_id = "s" + session.metadata = MagicMock() + session.metadata.builder_graph_id = "graph-1" + + agent_json = { + "id": "graph-1", + "name": "Demo", + "version": 7, + "nodes": [ + { + "id": "n1", + "block_id": "block-A", + "input_default": {"name": "Input"}, + "metadata": {}, + } + ], + "links": [], + } + guide_body = "# UNIQUE_GUIDE_MARKER body" + with ( + patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=agent_json), + ), + patch( + "backend.copilot.builder_context._load_guide", + return_value=guide_body, + ), + ): + suffix = await build_builder_system_prompt_suffix(session) + prefix = await build_builder_context_turn_prefix(session, "user-1") + + # System prompt suffix carries and the guide. + assert f"<{BUILDER_SESSION_TAG}>" in suffix + assert guide_body in suffix + # Dynamic bits must NOT be in the suffix — otherwise renames and + # cross-graph sessions invalidate Claude's prompt cache. + assert "graph-1" not in suffix + assert "Demo" not in suffix + + # Per-turn prefix carries with the full live + # snapshot (id, name, version, nodes) but NEVER the guide. + assert f"<{BUILDER_CONTEXT_TAG}>" in prefix + assert 'id="graph-1"' in prefix + assert 'name="Demo"' in prefix + assert 'version="7"' in prefix + assert guide_body not in prefix + assert "" not in prefix + + # Guide appears in the combined on-the-wire payload exactly ONCE. + combined = suffix + "\n\n" + prefix + assert combined.count(guide_body) == 1 + + +class TestApplyPromptCacheMarkers: + """Tests for _apply_prompt_cache_markers — Anthropic ephemeral + cache_control markers on baseline OpenRouter requests.""" + + def test_system_message_converted_to_content_blocks(self): + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["role"] == "system" + assert cached_messages[0]["content"] == [ + { + "type": "text", + "text": "You are helpful.", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + # User message must be untouched. + assert cached_messages[1] == {"role": "user", "content": "hello"} + + def test_system_message_preserves_unknown_fields(self): + # Future-proofing: a system message with extra keys (e.g. "name") must + # keep them after the content-blocks conversion. + messages = [ + {"role": "system", "content": "sys", "name": "developer"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["name"] == "developer" + assert cached_messages[0]["role"] == "system" + + def test_last_tool_gets_cache_control(self): + tools = [ + {"type": "function", "function": {"name": "a"}}, + {"type": "function", "function": {"name": "b"}}, + ] + + cached_tools = _mark_tools_with_cache_control(tools) + + assert "cache_control" not in cached_tools[0] + assert cached_tools[-1]["cache_control"] == { + "type": "ephemeral", + "ttl": "1h", + } + # Last tool's other fields preserved. 
+ assert cached_tools[-1]["function"] == {"name": "b"} + + def test_does_not_mutate_input(self): + messages = [{"role": "system", "content": "sys"}] + tools = [{"type": "function", "function": {"name": "a"}}] + + _mark_system_message_with_cache_control(messages) + _mark_tools_with_cache_control(tools) + + assert messages == [{"role": "system", "content": "sys"}] + assert tools == [{"type": "function", "function": {"name": "a"}}] + + def test_no_system_message_safe(self): + messages = [{"role": "user", "content": "hi"}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages == messages + + def test_empty_tools_safe(self): + assert _mark_tools_with_cache_control([]) == [] + + def test_non_string_system_content_left_untouched(self): + # If the content is already a list of blocks (e.g. caller pre-marked), + # the helper must not overwrite it. + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + messages = [{"role": "system", "content": pre_marked}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages[0]["content"] == pre_marked + + def test_is_anthropic_model_matches_claude_and_anthropic_prefix(self): + assert _is_anthropic_model("anthropic/claude-sonnet-4-6") + assert _is_anthropic_model("claude-3-5-sonnet-20241022") + assert _is_anthropic_model("anthropic.claude-3-5-sonnet-20241022-v2:0") + assert _is_anthropic_model("ANTHROPIC/Claude-Opus") # case insensitive + + def test_is_anthropic_model_rejects_other_providers(self): + assert not _is_anthropic_model("openai/gpt-4o") + assert not _is_anthropic_model("openai/gpt-5") + assert not _is_anthropic_model("google/gemini-2.5-pro") + assert not _is_anthropic_model("xai/grok-4") + assert not _is_anthropic_model("meta-llama/llama-3.3-70b-instruct") + + def test_is_anthropic_model_rejects_kimi_routes(self): + """Regression guard: Kimi K2.6 is a reasoning route (reasoning + extra_body is sent) but NOT an Anthropic route — Moonshot does + its own auto prompt caching, so ``cache_control`` markers must + NOT be applied. OpenRouter silently drops them today, but if + they ever start failing fast we'd want the gate tight.""" + assert not _is_anthropic_model("moonshotai/kimi-k2.6") + assert not _is_anthropic_model("moonshotai/kimi-k2-thinking") + assert not _is_anthropic_model("kimi-k2-instruct") + + def test_cache_control_uses_configured_ttl(self, monkeypatch): + """TTL comes from ChatConfig.baseline_prompt_cache_ttl — defaults + to 1h so the static prefix (system + tools) stays warm across + workspace users past the 5-min default window.""" + from backend.copilot.baseline import service as bsvc + + assert bsvc.config.baseline_prompt_cache_ttl == "1h" + cc = bsvc._fresh_ephemeral_cache_control() + assert cc == {"type": "ephemeral", "ttl": "1h"} + monkeypatch.setattr(bsvc.config, "baseline_prompt_cache_ttl", "5m") + assert bsvc._fresh_ephemeral_cache_control() == { + "type": "ephemeral", + "ttl": "5m", + } + + def test_fresh_helpers_return_distinct_objects(self): + """Regression guard: the `_fresh_*` helpers must return a NEW dict + on every call. 
A future refactor returning a module-level constant + would silently reintroduce the shared-mutable-state bug flagged + during earlier review cycles.""" + assert _fresh_ephemeral_cache_control() is not _fresh_ephemeral_cache_control() + assert ( + _fresh_anthropic_caching_headers() is not _fresh_anthropic_caching_headers() + ) + + def test_extract_cache_creation_tokens_openrouter_typed_attr(self): + """Newer ``openai-python`` declares ``cache_write_tokens`` as a + typed attribute on ``PromptTokensDetails`` — it no longer lands in + ``model_extra``. Verified empirically against the production + openai==1.113 installed in this venv: OpenRouter streaming + response populates ``ptd.cache_write_tokens`` directly while + ``ptd.model_extra`` is ``{}``. + """ + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate( + { + "audio_tokens": 0, + "cached_tokens": 0, + "cache_write_tokens": 4432, + "video_tokens": 0, + } + ) + assert getattr(ptd, "cache_write_tokens", None) == 4432 + assert _extract_cache_creation_tokens(ptd) == 4432 + + def test_extract_cache_creation_tokens_openrouter_model_extra(self): + """Older SDKs that don't yet declare ``cache_write_tokens`` as a + typed field leave it in ``model_extra`` — the helper must still + find it there.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + # Force the value into model_extra (simulates the old SDK shape + # where the field wasn't typed yet). + if ptd.model_extra is None: + # Pydantic v2 sometimes exposes __pydantic_extra__ as None when + # extras are disabled; initialise to a dict to mutate safely. + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_write_tokens"] = 7777 + assert _extract_cache_creation_tokens(ptd) == 7777 + + def test_extract_cache_creation_tokens_anthropic_native_field(self): + """Direct Anthropic API uses ``cache_creation_input_tokens`` — + falls through as the final path when neither + ``cache_write_tokens`` typed attr nor model_extra entry exists.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = 2048 + assert _extract_cache_creation_tokens(ptd) == 2048 + + def test_extract_cache_creation_tokens_absent(self): + """Neither provider field present → 0 (non-Anthropic routes or + cache-miss responses).""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + assert _extract_cache_creation_tokens(ptd) == 0 + + def test_build_cached_system_message_applies_cache_control(self): + """The single-message helper wraps the string content in a text block + with an ephemeral cache_control marker.""" + out = _build_cached_system_message({"role": "system", "content": "hi"}) + assert out["role"] == "system" + assert out["content"] == [ + { + "type": "text", + "text": "hi", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + + def test_build_cached_system_message_preserves_extra_fields(self): + """Unknown keys (e.g. 
``name``) survive the transformation.""" + out = _build_cached_system_message( + {"role": "system", "content": "sys", "name": "dev"} + ) + assert out["name"] == "dev" + assert out["role"] == "system" + + def test_build_cached_system_message_non_string_passthrough(self): + """Pre-marked list content is returned as-is (shallow-copied).""" + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + out = _build_cached_system_message({"role": "system", "content": pre_marked}) + assert out["content"] is pre_marked + + @pytest.mark.asyncio + async def test_baseline_llm_caller_memoises_cached_system_message(self): + """The cached system dict is built once and reused across rounds. + + Guards against the perf regression where the entire (growing) + ``messages`` list was copied on every tool-call iteration just to + mark the static system prompt. + """ + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + side_effect=[_make_stream_mock(chunk), _make_stream_mock(chunk)] + ) + + messages: list[dict] = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + first_cached = state.cached_system_message + assert first_cached is not None + # Simulate the tool-call loop growing ``messages`` between rounds. + messages.append({"role": "assistant", "content": "ok"}) + messages.append({"role": "user", "content": "follow up"}) + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + # Same dict instance reused — not rebuilt per round. + assert state.cached_system_message is first_cached + + # Second call's first message is the memoised system dict (not a new copy). + second_call_messages = mock_client.chat.completions.create.call_args_list[1][1][ + "messages" + ] + assert second_call_messages[0] is first_cached + # And the tail messages were spliced in, not re-copied. + assert second_call_messages[1] is messages[1] + assert second_call_messages[-1] is messages[-1] + + @pytest.mark.asyncio + async def test_baseline_llm_caller_skips_memoisation_for_non_anthropic(self): + """Non-Anthropic routes pass messages through unmodified — no cache + dict is built, no list splicing happens.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + messages: list[dict] = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + assert state.cached_system_message is None + # The exact same list object reaches the provider (no copy needed). 
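+        # The presumed service-side dispatch is roughly (sketch only; the
+        # gate and helper names are assumptions, not the production code):
+        #
+        #     if _supports_prompt_cache_markers(state.model):
+        #         if state.cached_system_message is None:
+        #             state.cached_system_message = _build_cached_system_message(messages[0])
+        #         request_messages = [state.cached_system_message, *messages[1:]]
+        #     else:
+        #         request_messages = messages  # pass-through, asserted below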
+ call_messages = mock_client.chat.completions.create.call_args[1]["messages"] + assert call_messages is messages + + +def _make_delta_chunk( + *, + content: str | None = None, + reasoning: str | None = None, + reasoning_details: list | None = None, + reasoning_content: str | None = None, + tool_calls: list | None = None, +): + """Build a streaming chunk with a configurable ``delta`` payload. + + The ``delta`` is a real ``ChoiceDelta`` pydantic instance so OpenRouter + extension fields land on ``delta.model_extra`` — which is how + :class:`OpenRouterDeltaExtension` reads them in production. Using a + raw ``MagicMock`` here would leave ``model_extra`` unset and silently + skip the reasoning parser. ``tool_calls`` (when provided) must be + ``MagicMock`` entries compatible with the service's streaming loop; + they're set on the delta via ``object.__setattr__`` because pydantic + would otherwise reject the non-schema types. + """ + from openai.types.chat.chat_completion_chunk import ChoiceDelta + + payload: dict = {"role": "assistant"} + if content is not None: + payload["content"] = content + if reasoning is not None: + payload["reasoning"] = reasoning + if reasoning_content is not None: + payload["reasoning_content"] = reasoning_content + if reasoning_details is not None: + payload["reasoning_details"] = reasoning_details + delta = ChoiceDelta.model_validate(payload) + # ChoiceDelta's tool_calls schema expects OpenAI-typed entries; bypass + # validation so tests can use MagicMocks that mimic the streaming shape. + if tool_calls is not None: + object.__setattr__(delta, "tool_calls", tool_calls) + + chunk = MagicMock() + chunk.usage = None + choice = MagicMock() + choice.delta = delta + chunk.choices = [choice] + return chunk + + +def _make_tool_call_delta(*, index: int, call_id: str, name: str, arguments: str): + """Build a ``delta.tool_calls[i]`` entry for streaming tool-use.""" + tc = MagicMock() + tc.index = index + tc.id = call_id + function = MagicMock() + function.name = name + function.arguments = arguments + tc.function = function + return tc + + +class TestBaselineReasoningStreaming: + """End-to-end reasoning event emission through ``_baseline_llm_caller``.""" + + @pytest.mark.asyncio + async def test_reasoning_then_text_emits_paired_events(self): + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + chunks = [ + _make_delta_chunk(reasoning="thinking..."), + _make_delta_chunk(reasoning=" more"), + _make_delta_chunk(content="final answer"), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(*chunks) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.emitted_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningDelta" in types + assert "StreamReasoningEnd" in types + + # Reasoning must close before text opens — AI SDK v5 rejects + # interleaved reasoning / text parts. + reason_end = types.index("StreamReasoningEnd") + text_start = types.index("StreamTextStart") + assert reason_end < text_start + + # All reasoning deltas share a single block id; the text block uses + # a fresh id after the reasoning-end rotation. 
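+        # Schematically, the expected event sequence for this stream is
+        # (block ids illustrative):
+        #
+        #     StreamReasoningStart(id="r1")
+        #     StreamReasoningDelta(id="r1", delta="thinking...")
+        #     StreamReasoningDelta(id="r1", delta=" more")
+        #     StreamReasoningEnd(id="r1")
+        #     StreamTextStart(id="t1")
+        #     StreamTextDelta(id="t1", delta="final answer")
+        #     StreamTextEnd(id="t1")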
+ reasoning_ids = { + e.id + for e in state.emitted_events + if isinstance( + e, (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd) + ) + } + text_ids = { + e.id + for e in state.emitted_events + if isinstance(e, (StreamTextStart, StreamTextDelta, StreamTextEnd)) + } + assert len(reasoning_ids) == 1 + assert len(text_ids) == 1 + assert reasoning_ids.isdisjoint(text_ids) + + combined = "".join( + e.delta for e in state.emitted_events if isinstance(e, StreamReasoningDelta) + ) + assert combined == "thinking... more" + + @pytest.mark.asyncio + async def test_reasoning_then_tool_call_closes_reasoning_first(self): + """A tool_call arriving mid-reasoning must close the reasoning block + before the tool-use is flushed — AI SDK v5 treats reasoning and + tool-use as distinct UI parts and rejects interleaving.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + chunks = [ + _make_delta_chunk(reasoning="deliberating..."), + _make_delta_chunk( + tool_calls=[ + _make_tool_call_delta( + index=0, + call_id="call_1", + name="search", + arguments='{"q":"x"}', + ) + ], + ), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(*chunks) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + response = await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + # A reasoning-end must have been emitted — this is the tool_calls + # branch's responsibility, not the stream-end cleanup. + types = [type(e).__name__ for e in state.emitted_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + + # The tool_call was collected — confirms the tool-use path executed + # after reasoning closed (rather than silently dropping the tool). + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "search" + + # No text events — this stream had no content deltas. + assert "StreamTextStart" not in types + + @pytest.mark.asyncio + async def test_reasoning_closed_on_mid_stream_exception(self): + """Regression guard: an exception during the streaming loop must + still emit ``StreamReasoningEnd`` (and ``StreamTextEnd`` when a + text block is open) before ``StreamFinishStep`` — the frontend + collapse relies on matched start/end pairs, and the outer handler + no longer patches these after-the-fact.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + async def failing_stream(): + yield _make_delta_chunk(reasoning="thinking...") + raise RuntimeError("boom") + + stream = MagicMock() + stream.close = AsyncMock() + stream.__aiter__ = lambda self: failing_stream() + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=stream) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + with pytest.raises(RuntimeError): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.emitted_events] + # The reasoning block was opened, the exception fired, and the + # finally block must have closed it before emitting the finish + # step. 
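+        # i.e. the service-side cleanup being exercised is assumed to look
+        # roughly like this (sketch; names follow this test, not the real code):
+        #
+        #     try:
+        #         async for chunk in stream:
+        #             ...
+        #     finally:
+        #         if state.reasoning_emitter.is_open:
+        #             emit(StreamReasoningEnd(...))   # close the open block first
+        #         emit(StreamFinishStep(...))
+        #         state.reasoning_emitter.reset()     # fresh ids on a retried round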
+ assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + assert "StreamFinishStep" in types + assert types.index("StreamReasoningEnd") < types.index("StreamFinishStep") + # Emitter is reset so a retried round starts with fresh ids. + assert state.reasoning_emitter.is_open is False + + @pytest.mark.asyncio + async def test_reasoning_param_sent_on_anthropic_routes(self): + """Anthropic route gets ``reasoning.max_tokens`` on the request.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" in extra_body + assert extra_body["reasoning"]["max_tokens"] > 0 + + @pytest.mark.asyncio + async def test_reasoning_param_absent_on_non_anthropic_routes(self): + """Non-reasoning routes (e.g. OpenAI) must not receive ``reasoning``.""" + state = _BaselineStreamState(model="openai/gpt-4o") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" not in extra_body + + @pytest.mark.asyncio + async def test_kimi_route_sends_reasoning_and_cache_control(self): + """Kimi K2.6 (Moonshot via OpenRouter's Anthropic-compat endpoint) + accepts ``cache_control: {type: ephemeral}`` on the system block + and the last tool — the endpoint honours the marker and lifts + cache hit rate on continuation turns from near-zero (Moonshot's + auto-caching drifts) to the Anthropic ~60-95% ballpark. The + ``anthropic-beta`` header stays off because Moonshot doesn't need + it; OpenRouter would strip the unknown header anyway.""" + state = _BaselineStreamState(model="moonshotai/kimi-k2.6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "hi"}, + ], + tools=[ + { + "type": "function", + "function": {"name": "echo", "parameters": {}}, + } + ], + state=state, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + extra_body = call_kwargs["extra_body"] + # Reasoning param on — the whole point of picking Kimi is the + # cheap-but-still-reasoning-capable path. + assert "reasoning" in extra_body + assert extra_body["reasoning"]["max_tokens"] > 0 + # No ``anthropic-beta`` header — that beta is specifically for + # native Anthropic endpoints; Moonshot's shim accepts + # ``cache_control`` without it, and sending it would be wasted + # bytes (OR strips it before forwarding to Moonshot). + assert "extra_headers" not in call_kwargs or not call_kwargs.get( + "extra_headers" + ) + # System block MUST carry ``cache_control`` so Moonshot's cache + # breakpoint is honoured. 
The cached system-message builder + # emits list-shape content with the marker on the first (and + # only) block — assert on that shape. + sys_msg = call_kwargs["messages"][0] + sys_content = sys_msg.get("content") + assert isinstance( + sys_content, list + ), "Cached system message should be a list-shape content block" + assert any( + "cache_control" in block for block in sys_content if isinstance(block, dict) + ), "Kimi system message should now carry cache_control markers" + # Tool-level cache marking is applied by ``stream_chat_completion_baseline`` + # (see ``_mark_tools_with_cache_control``) before tools reach + # ``_baseline_llm_caller``, so this unit test doesn't exercise + # that path — covered by the outer integration test. + + @pytest.mark.asyncio + async def test_reasoning_only_stream_still_closes_block(self): + """Regression: a stream with only reasoning (no text, no tool_call) + must still emit a matching ``reasoning-end`` at stream close so the + frontend Reasoning collapse finalises. Exercised here against + ``_baseline_llm_caller`` to cover the emitter's integration with + the finally-block, not just the unit emitter in reasoning_test.py. + """ + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock( + _make_delta_chunk(reasoning="just thinking"), + ) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.emitted_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + # No text was produced — no text events should be emitted. + assert "StreamTextStart" not in types + assert "StreamTextDelta" not in types + + @pytest.mark.asyncio + async def test_reasoning_param_suppressed_when_thinking_tokens_zero(self): + """Operator kill switch: setting ``claude_agent_max_thinking_tokens`` + to 0 removes the ``reasoning`` fragment from ``extra_body`` even on + an Anthropic route. Restores the zero-disables behaviour the old + ``baseline_reasoning_max_tokens`` config used to provide.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + patch( + "backend.copilot.baseline.service.config.claude_agent_max_thinking_tokens", + 0, + ), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" not in extra_body + + @pytest.mark.asyncio + async def test_reasoning_persists_to_state_session_messages(self): + """Integration guard: ``_BaselineStreamState.__post_init__`` wires + the emitter to ``state.session_messages``, so reasoning deltas + flowing through ``_baseline_llm_caller`` must produce a + ``role="reasoning"`` row on the state's session list. Catches + regressions where the wiring silently breaks (e.g. 
a refactor + passes the wrong list reference).""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock( + _make_delta_chunk(reasoning="first "), + _make_delta_chunk(reasoning="thought"), + _make_delta_chunk(content="answer"), + ) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + reasoning_rows = [m for m in state.session_messages if m.role == "reasoning"] + assert len(reasoning_rows) == 1 + assert reasoning_rows[0].content == "first thought" + + +class TestSupportsPromptCacheMarkers: + """``_supports_prompt_cache_markers`` is the widened gate for + emitting ``cache_control`` markers on message content. It's a + superset of ``_is_anthropic_model`` that ALSO admits Moonshot + (whose Anthropic-compat endpoint honours the marker) while keeping + the False answer for OpenAI / Grok / Gemini (which 400 on the + unknown field).""" + + @pytest.mark.parametrize( + "model", + [ + "anthropic/claude-sonnet-4-6", + "claude-3-5-sonnet-20241022", + "anthropic.claude-3-5-sonnet", + "ANTHROPIC/Claude-Opus", + ], + ) + def test_anthropic_routes_are_supported(self, model): + assert _supports_prompt_cache_markers(model) is True + + @pytest.mark.parametrize( + "model", + [ + "moonshotai/kimi-k2.6", + "moonshotai/kimi-k2-thinking", + "moonshotai/kimi-k2.5", + "moonshotai/kimi-k3.0", # future SKU + ], + ) + def test_moonshot_routes_are_supported(self, model): + """The whole reason this predicate exists — Moonshot must be + True even though ``_is_anthropic_model`` is False for it.""" + assert _supports_prompt_cache_markers(model) is True + # Verify this is strictly wider than the anthropic-only check. 
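+        # A predicate with exactly this behaviour could be as small as the
+        # sketch below (illustrative; the real gate may use a different rule):
+        #
+        #     def _supports_prompt_cache_markers(model: str) -> bool:
+        #         m = model.lower()
+        #         return _is_anthropic_model(m) or m.startswith("moonshotai/")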
+ assert _is_anthropic_model(model) is False + + @pytest.mark.parametrize( + "model", + [ + "openai/gpt-4o", + "google/gemini-2.5-pro", + "xai/grok-4", + "meta-llama/llama-3.3-70b-instruct", + "deepseek/deepseek-v3", + ], + ) + def test_other_providers_still_rejected(self, model): + """Regression guard: OpenAI/Grok/Gemini still 400 on + ``cache_control``, so the widened gate must keep them out.""" + assert _supports_prompt_cache_markers(model) is False + + +class TestBudgetExhaustedNoticeText: + """Tests for the fallback-notice decision used when the tool-round + budget is exhausted without a natural finish.""" + + def test_empty_text_returns_fallback(self): + assert _budget_exhausted_notice_text("") == _BUDGET_EXHAUSTED_FALLBACK_TEXT + + def test_whitespace_only_returns_fallback(self): + """A string of only whitespace is still "no visible response".""" + assert ( + _budget_exhausted_notice_text(" \n\t ") + == _BUDGET_EXHAUSTED_FALLBACK_TEXT + ) + + def test_non_empty_text_returns_none(self): + """When the model already summarised, stay quiet — no extra notice.""" + assert _budget_exhausted_notice_text("Here is what I did...") is None + + def test_fallback_text_is_user_facing(self): + """Guard against accidentally shipping an empty / internal string.""" + assert _BUDGET_EXHAUSTED_FALLBACK_TEXT.strip() + assert "tool-call budget" in _BUDGET_EXHAUSTED_FALLBACK_TEXT + assert "follow-up" in _BUDGET_EXHAUSTED_FALLBACK_TEXT + + +class TestBuildBudgetExhaustedFallbackEvents: + """Tests for the helper that produces the stream events + text mutation + for a budget-exhausted turn with no terminal-round text.""" + + def test_empty_terminal_text_emits_three_events(self): + events, to_append = _build_budget_exhausted_fallback_events("") + assert to_append == _BUDGET_EXHAUSTED_FALLBACK_TEXT + assert len(events) == 3 + assert isinstance(events[0], StreamTextStart) + assert isinstance(events[1], StreamTextDelta) + assert isinstance(events[2], StreamTextEnd) + # All three events share the same block id so the frontend groups + # them into a single text bubble. + assert events[0].id == events[1].id == events[2].id + # The delta carries the user-facing notice verbatim. + assert events[1].delta == _BUDGET_EXHAUSTED_FALLBACK_TEXT + + def test_non_empty_terminal_text_returns_empty(self): + """Model already produced visible final text → no fallback.""" + events, to_append = _build_budget_exhausted_fallback_events( + "Here's what I did so far..." 
+ ) + assert events == [] + assert to_append == "" + + def test_whitespace_only_still_emits_fallback(self): + events, to_append = _build_budget_exhausted_fallback_events(" \n\t ") + assert len(events) == 3 + assert to_append == _BUDGET_EXHAUSTED_FALLBACK_TEXT + + def test_each_call_uses_fresh_block_id(self): + """Block IDs are UUIDs — two invocations must not collide.""" + events_a, _ = _build_budget_exhausted_fallback_events("") + events_b, _ = _build_budget_exhausted_fallback_events("") + assert events_a[0].id != events_b[0].id diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index 8d6fb50a53..8a9e435743 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -63,21 +63,117 @@ def _make_session_messages(*roles: str) -> list[ChatMessage]: class TestResolveBaselineModel: - """Baseline model resolution honours the per-request tier toggle.""" + """Baseline model resolution honours the per-request tier toggle. - def test_advanced_tier_selects_advanced_model(self): - assert _resolve_baseline_model("advanced") == config.advanced_model + Baseline reads the ``fast_*_model`` cells of the (path, tier) matrix + and never falls through to the SDK-side ``thinking_*_model`` cells. + Without a user_id (so no LD context) the resolver returns the + ``ChatConfig`` static default; per-user overrides are exercised in + ``copilot/model_router_test.py``. + """ - def test_standard_tier_selects_default_model(self): - assert _resolve_baseline_model("standard") == config.model + @pytest.mark.asyncio + async def test_advanced_tier_selects_fast_advanced_model(self): + assert ( + await _resolve_baseline_model("advanced", None) + == config.fast_advanced_model + ) - def test_none_tier_selects_default_model(self): - """Baseline users without a tier MUST keep the default (standard).""" - assert _resolve_baseline_model(None) == config.model + @pytest.mark.asyncio + async def test_standard_tier_selects_fast_standard_model(self): + assert ( + await _resolve_baseline_model("standard", None) + == config.fast_standard_model + ) - def test_standard_and_advanced_models_differ(self): - """Advanced tier defaults to a different (Opus) model than standard.""" - assert config.model != config.advanced_model + @pytest.mark.asyncio + async def test_none_tier_selects_fast_standard_model(self): + """Baseline users without a tier get the fast-standard default.""" + assert await _resolve_baseline_model(None, None) == config.fast_standard_model + + def test_fast_standard_default_is_sonnet(self): + """Shipped default: Sonnet on the baseline standard cell — the + non-Anthropic routes ship via the LD flag instead of a config + change. 
Asserts the declared ``Field`` default so a deploy-time + ``CHAT_FAST_STANDARD_MODEL`` override doesn't flake CI.""" + from backend.copilot.config import ChatConfig + + assert ( + ChatConfig.model_fields["fast_standard_model"].default + == "anthropic/claude-sonnet-4-6" + ) + + def test_fast_advanced_default_is_opus(self): + """Shipped default: Opus on the baseline advanced cell — mirrors + the SDK advanced cell so the advanced-tier A/B stays clean + (same model, different path).""" + from backend.copilot.config import ChatConfig + + assert ( + ChatConfig.model_fields["fast_advanced_model"].default + == "anthropic/claude-opus-4.7" + ) + + def test_standard_and_advanced_cells_differ_on_fast(self): + """Advanced tier defaults to a different model than standard on + the baseline path. Checked against declared ``Field`` defaults + so operator env overrides don't flake the test.""" + from backend.copilot.config import ChatConfig + + assert ( + ChatConfig.model_fields["fast_standard_model"].default + != ChatConfig.model_fields["fast_advanced_model"].default + ) + + def test_legacy_env_aliases_route_to_new_fields(self, monkeypatch): + """Backward compat: the pre-split env var names must still bind. + + The four-field matrix was introduced with ``validation_alias`` + entries so that existing deployments setting ``CHAT_MODEL`` / + ``CHAT_ADVANCED_MODEL`` / ``CHAT_FAST_MODEL`` continue to override + the same effective cell without a rename. Construct a fresh + ``ChatConfig`` with each legacy name set and confirm it lands on + the new field. + """ + from backend.copilot.config import ChatConfig + + monkeypatch.setenv("CHAT_MODEL", "legacy/sonnet-via-chat-model") + monkeypatch.setenv("CHAT_ADVANCED_MODEL", "legacy/opus-via-advanced") + monkeypatch.setenv("CHAT_FAST_MODEL", "legacy/fast-via-fast-model") + + cfg = ChatConfig() + + assert cfg.thinking_standard_model == "legacy/sonnet-via-chat-model" + assert cfg.thinking_advanced_model == "legacy/opus-via-advanced" + assert cfg.fast_standard_model == "legacy/fast-via-fast-model" + + def test_all_four_new_env_vars_bind_to_their_cells(self, monkeypatch): + """Each of the four (path, tier) cells must be overridable via + its documented ``CHAT_*_*_MODEL`` env var — including + ``CHAT_FAST_ADVANCED_MODEL`` which was missing a + ``validation_alias`` in the original split and only bound + implicitly through ``env_prefix``. Pinning all four here so + that whenever someone touches the config shape, an accidental + unbinding fails CI instead of silently ignoring operator + overrides. + """ + from backend.copilot.config import ChatConfig + + monkeypatch.setenv("CHAT_FAST_STANDARD_MODEL", "explicit/fast-std") + monkeypatch.setenv("CHAT_FAST_ADVANCED_MODEL", "explicit/fast-adv") + monkeypatch.setenv("CHAT_THINKING_STANDARD_MODEL", "explicit/think-std") + monkeypatch.setenv("CHAT_THINKING_ADVANCED_MODEL", "explicit/think-adv") + # Clear the legacy aliases so they don't win priority in + # ``AliasChoices`` (first match wins). 
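+        # The binding under test is assumed to be declared on ChatConfig
+        # roughly like this (sketch; real Field kwargs may differ):
+        #
+        #     fast_standard_model: str = Field(
+        #         default="anthropic/claude-sonnet-4-6",
+        #         validation_alias=AliasChoices("CHAT_FAST_STANDARD_MODEL", "CHAT_FAST_MODEL"),
+        #     )
+        #     fast_advanced_model: str = Field(
+        #         default="anthropic/claude-opus-4.7",
+        #         validation_alias=AliasChoices("CHAT_FAST_ADVANCED_MODEL"),
+        #     )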
+ for legacy in ("CHAT_MODEL", "CHAT_ADVANCED_MODEL", "CHAT_FAST_MODEL"): + monkeypatch.delenv(legacy, raising=False) + + cfg = ChatConfig() + + assert cfg.fast_standard_model == "explicit/fast-std" + assert cfg.fast_advanced_model == "explicit/fast-adv" + assert cfg.thinking_standard_model == "explicit/think-std" + assert cfg.thinking_advanced_model == "explicit/think-adv" class TestLoadPriorTranscript: diff --git a/autogpt_platform/backend/backend/copilot/builder_context.py b/autogpt_platform/backend/backend/copilot/builder_context.py new file mode 100644 index 0000000000..9f36350d1c --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/builder_context.py @@ -0,0 +1,217 @@ +"""Builder-session context helpers — split cacheable system prompt from +the volatile per-turn snapshot so Claude's prompt cache stays warm.""" + +from __future__ import annotations + +import logging +from typing import Any + +from backend.copilot.model import ChatSession +from backend.copilot.permissions import CopilotPermissions +from backend.copilot.tools.agent_generator import get_agent_as_json +from backend.copilot.tools.get_agent_building_guide import _load_guide + +logger = logging.getLogger(__name__) + + +BUILDER_CONTEXT_TAG = "builder_context" +BUILDER_SESSION_TAG = "builder_session" + + +# Tools hidden from builder-bound sessions: ``create_agent`` / +# ``customize_agent`` would mint a new graph (panel is bound to one), +# and ``get_agent_building_guide`` duplicates bytes already in the +# system-prompt suffix. Everything else (find_block, find_agent, …) +# stays available so the LLM can look up ids instead of hallucinating. +BUILDER_BLOCKED_TOOLS: tuple[str, ...] = ( + "create_agent", + "customize_agent", + "get_agent_building_guide", +) + + +def resolve_session_permissions( + session: ChatSession | None, +) -> CopilotPermissions | None: + """Blacklist :data:`BUILDER_BLOCKED_TOOLS` for builder-bound sessions, + return ``None`` (unrestricted) otherwise.""" + if session is None or not session.metadata.builder_graph_id: + return None + return CopilotPermissions( + tools=list(BUILDER_BLOCKED_TOOLS), + tools_exclude=True, + ) + + +# Caps — mirror the frontend ``serializeGraphForChat`` defaults so the +# server-side block stays within a practical token budget for large graphs. +_MAX_NODES = 100 +_MAX_LINKS = 200 + +_FETCH_FAILED_PREFIX = ( + f"<{BUILDER_CONTEXT_TAG}>\n" + f"fetch_failed\n" + f"\n\n" +) + +# Embedded in the cacheable suffix so the LLM picks the right run_agent +# dispatch mode without forcing the user to watch a long-blocking call. +_BUILDER_RUN_AGENT_GUIDANCE = ( + "You are operating inside the builder panel, not the standalone " + "copilot page. The builder page already subscribes to agent " + "executions the moment you return an execution_id, so for REAL " + "(non-dry) runs prefer `run_agent(dry_run=False, wait_for_result=0)` " + "— the user will see the run stream in the builder's execution panel " + "in-place and your turn ends immediately with the id. For DRY-RUNS " + "keep `dry_run=True, wait_for_result=120`: blocking is required so " + "you can inspect `execution.node_executions` and report the verdict " + "in the same turn." 
+)
+
+
+def _sanitize_for_xml(value: Any) -> str:
+    """Escape XML special chars — mirrors ``sanitizeForXml`` in
+    ``BuilderChatPanel/helpers.ts``."""
+    s = "" if value is None else str(value)
+    return (
+        s.replace("&", "&amp;")
+        .replace("<", "&lt;")
+        .replace(">", "&gt;")
+        .replace('"', "&quot;")
+        .replace("'", "&#39;")
+    )
+
+
+def _node_display_name(node: dict[str, Any]) -> str:
+    """Prefer the user-set label (``input_default.name`` / ``metadata.title``);
+    fall back to the block id."""
+    defaults = node.get("input_default") or {}
+    metadata = node.get("metadata") or {}
+    for key in ("name", "title", "label"):
+        value = defaults.get(key) or metadata.get(key)
+        if isinstance(value, str) and value.strip():
+            return value.strip()
+    block_id = node.get("block_id") or ""
+    return block_id or "unknown"
+
+
+def _format_nodes(nodes: list[dict[str, Any]]) -> str:
+    if not nodes:
+        return "\n"
+    visible = nodes[:_MAX_NODES]
+    lines = []
+    for node in visible:
+        node_id = _sanitize_for_xml(node.get("id") or "")
+        name = _sanitize_for_xml(_node_display_name(node))
+        block_id = _sanitize_for_xml(node.get("block_id") or "")
+        lines.append(f"- {node_id}: {name} ({block_id})")
+    extra = len(nodes) - len(visible)
+    if extra > 0:
+        lines.append(f"({extra} more not shown)")
+    body = "\n".join(lines)
+    return f"\n{body}\n"
+
+
+def _format_links(
+    links: list[dict[str, Any]],
+    nodes: list[dict[str, Any]],
+) -> str:
+    if not links:
+        return "\n"
+    name_by_id = {n.get("id"): _node_display_name(n) for n in nodes}
+    visible = links[:_MAX_LINKS]
+    lines = []
+    for link in visible:
+        src_id = link.get("source_id") or ""
+        dst_id = link.get("sink_id") or ""
+        src_name = name_by_id.get(src_id, src_id)
+        dst_name = name_by_id.get(dst_id, dst_id)
+        src_out = link.get("source_name") or ""
+        dst_in = link.get("sink_name") or ""
+        lines.append(
+            f"- {_sanitize_for_xml(src_name)}.{_sanitize_for_xml(src_out)} "
+            f"-> {_sanitize_for_xml(dst_name)}.{_sanitize_for_xml(dst_in)}"
+        )
+    extra = len(links) - len(visible)
+    if extra > 0:
+        lines.append(f"({extra} more not shown)")
+    body = "\n".join(lines)
+    return f"\n{body}\n"
+
+
+async def build_builder_system_prompt_suffix(session: ChatSession) -> str:
+    """Return the cacheable system-prompt suffix for a builder session.
+
+    Holds only static content (dispatch guidance + building guide) so the
+    bytes are identical across turns AND across sessions for different
+    graphs — the live id/name/version ride on the per-turn prefix.
+    """
+    if not session.metadata.builder_graph_id:
+        return ""
+
+    try:
+        guide = _load_guide()
+    except Exception:
+        logger.exception("[builder_context] Failed to load agent-building guide")
+        return ""
+
+    # The guide is trusted server-side content (read from disk). We do NOT
+    # escape it — the LLM needs the raw markdown to make sense of block ids,
+    # code fences, and example JSON.
+    return (
+        f"\n\n<{BUILDER_SESSION_TAG}>\n"
+        f"\n"
+        f"{_BUILDER_RUN_AGENT_GUIDANCE}\n"
+        f"\n"
+        f"\n{guide}\n\n"
+        f"</{BUILDER_SESSION_TAG}>"
+    )
+
+
+async def build_builder_context_turn_prefix(
+    session: ChatSession,
+    user_id: str | None,
+) -> str:
+    """Return the per-turn ``<builder_context>`` prefix with the live
+    graph snapshot (id/name/version/nodes/links).
+    ``""`` for non-builder sessions; fetch-failure marker if the graph
+    cannot be read."""
+    graph_id = session.metadata.builder_graph_id
+    if not graph_id:
+        return ""
+
+    try:
+        agent_json = await get_agent_as_json(graph_id, user_id)
+    except Exception:
+        logger.exception(
+            "[builder_context] Failed to fetch graph %s for session %s",
+            graph_id,
+            session.session_id,
+        )
+        return _FETCH_FAILED_PREFIX
+
+    if not agent_json:
+        logger.warning(
+            "[builder_context] Graph %s not found for session %s",
+            graph_id,
+            session.session_id,
+        )
+        return _FETCH_FAILED_PREFIX
+
+    version = _sanitize_for_xml(agent_json.get("version") or "")
+    raw_name = agent_json.get("name")
+    graph_name = (
+        raw_name.strip() if isinstance(raw_name, str) and raw_name.strip() else None
+    )
+    nodes = agent_json.get("nodes") or []
+    links = agent_json.get("links") or []
+    name_attr = f' name="{_sanitize_for_xml(graph_name)}"' if graph_name else ""
+    graph_tag = (
+        f''
+    )
+
+    inner = f"{graph_tag}\n{_format_nodes(nodes)}\n{_format_links(links, nodes)}"
+    return f"<{BUILDER_CONTEXT_TAG}>\n{inner}\n\n\n"
diff --git a/autogpt_platform/backend/backend/copilot/builder_context_test.py b/autogpt_platform/backend/backend/copilot/builder_context_test.py
new file mode 100644
index 0000000000..efeb6f7dad
--- /dev/null
+++ b/autogpt_platform/backend/backend/copilot/builder_context_test.py
@@ -0,0 +1,329 @@
+"""Tests for the split builder-context helpers.
+
+Covers both halves of the public API:
+
+- :func:`build_builder_system_prompt_suffix` — session-stable block
+  appended to the system prompt (contains the dispatch guidance + building
+  guide; no per-graph data).
+- :func:`build_builder_context_turn_prefix` — per-turn user-message
+  prefix (contains the live graph id/name/version + node/link snapshot).
+"""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from backend.copilot.builder_context import (
+    BUILDER_CONTEXT_TAG,
+    BUILDER_SESSION_TAG,
+    build_builder_context_turn_prefix,
+    build_builder_system_prompt_suffix,
+)
+from backend.copilot.model import ChatSession
+
+
+def _session(
+    builder_graph_id: str | None,
+    *,
+    user_id: str = "test-user",
+) -> ChatSession:
+    """Minimal ``ChatSession`` with *builder_graph_id* on metadata."""
+    return ChatSession.new(
+        user_id,
+        dry_run=False,
+        builder_graph_id=builder_graph_id,
+    )
+
+
+def _agent_json(
+    nodes: list[dict] | None = None,
+    links: list[dict] | None = None,
+    **overrides,
+) -> dict:
+    base: dict = {
+        "id": "graph-1",
+        "name": "My Agent",
+        "description": "A test agent",
+        "version": 3,
+        "is_active": True,
+        "nodes": nodes if nodes is not None else [],
+        "links": links if links is not None else [],
+    }
+    base.update(overrides)
+    return base
+
+
+# ---------------------------------------------------------------------------
+# build_builder_system_prompt_suffix
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_system_prompt_suffix_empty_for_non_builder():
+    session = _session(None)
+    result = await build_builder_system_prompt_suffix(session)
+    assert result == ""
+
+
+@pytest.mark.asyncio
+async def test_system_prompt_suffix_contains_only_static_content():
+    session = _session("graph-1")
+    with patch(
+        "backend.copilot.builder_context._load_guide",
+        return_value="# Guide body",
+    ):
+        suffix = await build_builder_system_prompt_suffix(session)
+
+    assert suffix.startswith("\n\n")
+    assert f"<{BUILDER_SESSION_TAG}>" in suffix
+    assert f"" in suffix
+    assert "" in suffix
+    assert "# Guide body" in suffix
+    # Dispatch-mode guidance must appear so the LLM knows to prefer
+    # wait_for_result=0 for real runs (builder UI subscribes live) and
+    # wait_for_result=120 for dry-runs (so it can inspect the node trace).
+    assert "" in suffix
+    assert "wait_for_result=0" in suffix
+    assert "wait_for_result=120" in suffix
+    # Regression: dynamic graph id/name must NOT leak into the cacheable
+    # suffix — they live in the per-turn prefix so renames and cross-graph
+    # sessions don't invalidate Claude's prompt cache.
+    assert "graph-1" not in suffix
+    assert "id=" not in suffix
+    assert "name=" not in suffix
+
+
+@pytest.mark.asyncio
+async def test_system_prompt_suffix_identical_across_graphs():
+    """The suffix must be byte-identical regardless of which graph the
+    session is bound to — that's what keeps the cacheable prefix warm
+    across sessions."""
+    s1 = _session("graph-1")
+    s2 = _session("graph-2", user_id="different-owner")
+    with patch(
+        "backend.copilot.builder_context._load_guide",
+        return_value="# Guide body",
+    ):
+        suffix_1 = await build_builder_system_prompt_suffix(s1)
+        suffix_2 = await build_builder_system_prompt_suffix(s2)
+
+    assert suffix_1 == suffix_2
+
+
+@pytest.mark.asyncio
+async def test_system_prompt_suffix_empty_when_guide_load_fails():
+    """Guide load failure means we have nothing useful to add — emit an
+    empty suffix rather than a half-built block."""
+    session = _session("graph-1")
+    with patch(
+        "backend.copilot.builder_context._load_guide",
+        side_effect=OSError("missing"),
+    ):
+        suffix = await build_builder_system_prompt_suffix(session)
+
+    assert suffix == ""
+
+
+# ---------------------------------------------------------------------------
+# build_builder_context_turn_prefix
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_turn_prefix_empty_for_non_builder():
+    session = _session(None)
+    result = await build_builder_context_turn_prefix(session, "user-1")
+    assert result == ""
+
+
+@pytest.mark.asyncio
+async def test_turn_prefix_contains_version_nodes_and_links():
+    session = _session("graph-1")
+    nodes = [
+        {
+            "id": "n1",
+            "block_id": "block-A",
+            "input_default": {"name": "Input"},
+            "metadata": {},
+        },
+        {
+            "id": "n2",
+            "block_id": "block-B",
+            "input_default": {},
+            "metadata": {},
+        },
+    ]
+    links = [
+        {
+            "source_id": "n1",
+            "sink_id": "n2",
+            "source_name": "out",
+            "sink_name": "in",
+        }
+    ]
+    agent = _agent_json(nodes=nodes, links=links)
+    with patch(
+        "backend.copilot.builder_context.get_agent_as_json",
+        new=AsyncMock(return_value=agent),
+    ):
+        block = await build_builder_context_turn_prefix(session, "user-1")
+
+    assert block.startswith(f"<{BUILDER_CONTEXT_TAG}>\n")
+    assert block.endswith(f"\n\n")
+    assert 'id="graph-1"' in block
+    assert 'name="My Agent"' in block
+    assert 'version="3"' in block
+    assert 'node_count="2"' in block
+    assert 'edge_count="1"' in block
+    assert "n1: Input (block-A)" in block
+    assert "n2: block-B (block-B)" in block
+    assert "Input.out -> block-B.in" in block
+
+
+@pytest.mark.asyncio
+async def test_turn_prefix_does_not_include_guide():
+    """The guide lives in the cacheable system prompt, not in the per-turn
+    prefix."""
+    session = _session("graph-1")
+    with (
+        patch(
+            "backend.copilot.builder_context.get_agent_as_json",
+            new=AsyncMock(return_value=_agent_json()),
+        ),
+        # Sentinel guide text — if it leaks into the turn prefix the
+        # assertion below catches it.
+ patch( + "backend.copilot.builder_context._load_guide", + return_value="SENTINEL_GUIDE_BODY", + ), + ): + block = await build_builder_context_turn_prefix(session, "user-1") + + assert "SENTINEL_GUIDE_BODY" not in block + assert "" not in block + + +@pytest.mark.asyncio +async def test_turn_prefix_escapes_graph_name(): + session = _session("graph-1") + with patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=_agent_json(name='", - description: "", - hardcodedValues: {}, - inputSchema: {}, - outputSchema: {}, - uiType: 1, - block_id: "b1", - costs: [], - categories: [], - }, - type: "custom" as const, - position: { x: 0, y: 0 }, - }, - ] as unknown as CustomNode[]; - - const result = serializeGraphForChat(nodes, []); - expect(result).not.toContain("`; - const wrapped = wrapWithHeadInjection(content, tailwindScript); + const wrapped = wrapWithHeadInjection( + content, + tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT, + ); return (