Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-04-30 03:00:41 -04:00)
fix(platform): resolve merge conflicts with dev, update tier names
- Resolve conflicts in test files and SubscriptionTierSection
- Update TIER_WORKSPACE_STORAGE_MB for renamed tiers: FREE→BASIC, add MAX
- Update tier descriptions in helpers.ts with storage limits
- Update rate_limit_test.py for new tier names

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
245  .claude/skills/pr-polish/SKILL.md  Normal file
@@ -0,0 +1,245 @@
---
name: pr-polish
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
user-invocable: true
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# PR Polish

**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:

1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
3. Every non-bot, non-author top-level review has been acknowledged (replied to) OR resolved via a thread it spawned.
4. Every non-bot, non-author issue comment has been acknowledged (replied to).
5. Every CI check is `conclusion: "success"`, `"skipped"`, or `"neutral"` — none is `"failure"` or still pending.
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient.

**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 rounds in practice.

## TodoWrite

Before starting, write two todos so the user can see the loop progression:

- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.

Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).

## Find the PR

```bash
ARG_PR="${ARG:-}"
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
  ARG_PR="${BASH_REMATCH[1]}"
fi
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
  echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
  exit 1
fi
echo "Polishing PR #$PR"
```

## The outer loop

```text
round = 0
while round < _MAX_ROUNDS:
    round += 1
    baseline = snapshot_state(PR)     # see "Snapshotting state" below
    invoke_skill("pr-review", PR)     # posts findings as inline comments / top-level review
    findings = diff_state(PR, baseline)
    if findings.total == 0:
        break                         # no new findings → go to polish polling
    invoke_skill("pr-address", PR)    # resolves every unresolved thread + CI failure

# Post-loop: polish polling (see below).
polish_polling(PR)
```

### Snapshotting state

Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):

```bash
# Inline threads — total count + latest databaseId per thread
gh api graphql -f query="
{
  repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
    pullRequest(number: ${PR}) {
      reviewThreads(first: 100) {
        totalCount
        nodes {
          id
          isResolved
          comments(last: 1) { nodes { databaseId } }
        }
      }
    }
  }
}" > /tmp/baseline_threads.json

# Top-level reviews — count + latest id per non-empty review
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
  --jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
  > /tmp/baseline_reviews.json

# Issue comments — count + latest id per non-bot, non-author comment.
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
# accounts like coderabbitai, github-actions, sentry-io). The author is
# filtered by comparing login to the PR author — pass it to plain jq with
# --arg (gh api itself has no --arg flag).
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate |
  jq --arg author "$AUTHOR" \
    '[.[] | select(.user.type != "Bot" and .user.login != $author)
      | {id, user: .user.login, created_at}]' \
  > /tmp/baseline_issue_comments.json
```

### Diffing after a review

After `/pr-review` runs, any of the following counts as a "new finding" and means another address round is needed:

- A new inline thread `id` not in the baseline.
- An existing thread whose latest comment `databaseId` is higher than the baseline's (a new reply on an old thread).
- A new top-level review `id` with a non-empty body.
- A new issue comment `id` from a non-bot, non-author user.

If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.

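A minimal sketch of the inline-thread bucket, assuming `/tmp/baseline_threads.json` from the snapshot step and a `/tmp/current_threads.json` produced by re-running the same GraphQL query after the review (the second file name is illustrative):

```bash
# Count threads that are new, or grew a reply, since the baseline.
NEW_THREADS=$(jq -n \
  --slurpfile base /tmp/baseline_threads.json \
  --slurpfile cur /tmp/current_threads.json '
  ($base[0].data.repository.pullRequest.reviewThreads.nodes
    | map({key: .id, value: .comments.nodes[-1].databaseId}) | from_entries) as $b
  | [ $cur[0].data.repository.pullRequest.reviewThreads.nodes[]
      | select((.id | in($b) | not)
               or (.comments.nodes[-1].databaseId > $b[.id])) ]
  | length')
echo "New/updated inline threads since baseline: $NEW_THREADS"
```

The review and issue-comment buckets diff the same way against their own baseline files.
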
## Polish polling

Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals:

```text
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
clean_polls = 0
required_clean = 2
while clean_polls < required_clean:
    # 1. CI gate — any terminal non-success conclusion (not just "failure")
    #    must trigger /pr-address. "success", "skipped", "neutral" are clean;
    #    anything else (including cancelled, timed_out, action_required) is a
    #    blocker that won't self-resolve.
    ci = fetch_check_runs(PR)
    if any ci.conclusion in NON_SUCCESS_TERMINAL:
        invoke_skill("pr-address", PR)   # address failures + any new comments
        baseline = snapshot_state(PR)    # reset — push during address invalidates old baseline
        clean_polls = 0
        continue
    if any ci.conclusion is None (still in_progress):
        sleep 60; continue               # wait without counting this as clean

    # 2. Comment / thread gate
    threads = fetch_unresolved_threads(PR)
    new_issue_comments = diff_against_baseline(issue_comments)
    new_reviews = diff_against_baseline(reviews)
    if threads or new_issue_comments or new_reviews:
        invoke_skill("pr-address", PR)
        baseline = snapshot_state(PR)    # reset — the address loop just dealt with these;
                                         # otherwise they stay "new" relative to the old baseline forever
        clean_polls = 0
        continue

    # 3. Mergeability gate
    mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
    if mergeable == false (CONFLICTING):
        resolve_conflicts(PR)            # see pr-address skill
        clean_polls = 0
        continue
    if mergeable is null (UNKNOWN):
        sleep 60; continue

    clean_polls += 1
    sleep 60
```

Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
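
A concrete sketch of the CI gate, assuming the head commit's check runs fit in one page of 100 (add pagination if they don't):

```bash
SHA=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.head.sha')
gh api "repos/Significant-Gravitas/AutoGPT/commits/${SHA}/check-runs?per_page=100" \
  --jq '{
    pending:  ([.check_runs[] | select(.conclusion == null)] | length),
    blockers: ([.check_runs[].conclusion
                | select(. != null and (IN("success", "skipped", "neutral") | not))] | length)
  }'
```

`pending > 0` means sleep without counting the poll; `blockers > 0` means invoke `/pr-address`.
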
### Why 2 clean polls, not 1

A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive quiet cycles, with no new threads, reviews, or issue comments arriving, give high confidence that nothing else is incoming.

### Why check every source on each poll

`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:

- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles).
- Issue comments from human reviewers (not caught by inline thread polling).
- Sentry bug predictions that land on new line numbers post-push.
- Merge conflicts introduced by a race between your push and a merge to `dev`.

## Invocation pattern

Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.

```python
Skill(skill="pr-review", args=pr_url)
Skill(skill="pr-address", args=pr_url)
```

After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.

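A minimal verification sketch for the unresolved-thread count, queried straight from GraphQL before reporting done:

```bash
UNRESOLVED=$(gh api graphql -f query="
{
  repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
    pullRequest(number: ${PR}) {
      reviewThreads(first: 100) { nodes { isResolved } }
    }
  }
}" --jq '[.data.repository.pullRequest.reviewThreads.nodes[]
          | select(.isResolved | not)] | length')
[ "$UNRESOLVED" -eq 0 ] || echo "Not done: $UNRESOLVED unresolved threads"
```
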
### **Auto-continue: do NOT end your response between child skills**

`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:

- Do NOT summarize and stop.
- Do NOT wait for user confirmation to continue.
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.

The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).

If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.

## GitHub rate limits

This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off (sketch below):

- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.

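A minimal sketch of that backoff check:

```bash
REMAINING=$(gh api rate_limit --jq '.resources.graphql.remaining')
if [ "${REMAINING:-0}" -lt 200 ]; then
  echo "GraphQL budget low ($REMAINING): switch reads to REST, defer thread resolutions"
fi
```
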
## Safety valves

- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions (sketched below). A failing format check is a CI `failure` that will never self-resolve.
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved.

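The format pass, assuming the standard AutoGPT layout (backend managed by poetry, frontend by pnpm):

```bash
(cd autogpt_platform/backend && poetry run format)
(cd autogpt_platform/frontend && pnpm format && pnpm lint && pnpm types)
```
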
## Reporting

When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary:

```
PR #{N} polish complete ({rounds_completed} rounds):
- {X} inline threads opened and resolved
- {Y} CI failures fixed
- {Z} new commits pushed
Final state: CI green, {total} threads all resolved, mergeable.
```

If exiting via `_MAX_ROUNDS`, flag explicitly:

```
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
- {N} threads still unresolved: {titles}
- CI status: {summary}
Needs human review.
```

## When to use this skill

Use when the user says any of:
- "polish this PR"
- "keep reviewing and addressing until it's mergeable"
- "loop /pr-review + /pr-address until done"
- "make sure the PR is actually merge-ready"

Do **not** use when:
- User wants just one review pass (→ `/pr-review`).
- User wants to address already-posted comments without further self-review (→ `/pr-address`).
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.

@@ -186,7 +186,7 @@ Multiple worktrees share the same host — Docker infra (postgres, redis, clamav

### Lock file contract

-Path (**always** the root worktree so all siblings see it): `/Users/majdyz/Code/AutoGPT/.ign.testing.lock`
+Path (**always** the root worktree so all siblings see it): `$REPO_ROOT/.ign.testing.lock`

Body (one `key=value` per line):
```
@@ -202,7 +202,7 @@ intent=<one-line description + rough duration>

### Claim

```bash
-LOCK=/Users/majdyz/Code/AutoGPT/.ign.testing.lock
+LOCK=$REPO_ROOT/.ign.testing.lock
NOW=$(date -u +%Y-%m-%dT%H:%MZ)
STALE_AFTER_MIN=5

@@ -252,7 +252,7 @@ echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
-  >> /Users/majdyz/Code/AutoGPT/.ign.testing.log
+  >> $REPO_ROOT/.ign.testing.log
```

Use a `trap` so release runs even on `exit 1`:
@@ -260,12 +260,38 @@ Use a `trap` so release runs even on `exit 1`:
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
```

### **Release the lock AS SOON AS the test run is done**

The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if:

- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually.
- You're leaving docker containers up.
- You're tailing logs for a minute or two.

Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone.

Concretely, the sequence at the end of every `/pr-test` run (success or failure) is:

```bash
# 1. Write the final report + post PR comment — done above in Step 5/6.
# 2. Release the lock right now, even if the app is still up.
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
  >> $REPO_ROOT/.ign.testing.log
# 3. Optionally leave the app running and note it so the user knows:
echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
echo "  pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
```

If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test.

### Shared status log

-`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
+`$REPO_ROOT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
```bash
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
-  >> /Users/majdyz/Code/AutoGPT/.ign.testing.log
+  >> $REPO_ROOT/.ign.testing.log
```

## Step 3: Environment setup

@@ -755,6 +781,19 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations

**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as an inline image (`![name](url)`) in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.

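A one-line version of that check, assuming the comment body lives in `$COMMENT_FILE` (the same variable the sanity check below uses):

```bash
grep -q '!\[' "$COMMENT_FILE" || { echo "ABORT: no inline screenshots in PR comment."; exit 1; }
```
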
**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path.

**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`):

```bash
# Reject any local-looking absolute path or home-dir shortcut in the body
if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then
  echo "ABORT: local filesystem paths detected in PR comment body."
  echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting."
  exit 1
fi
```

```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"

1  .gitignore  vendored
@@ -195,3 +195,4 @@ test.db

# Implementation plans (generated by AI agents)
plans/
.claude/worktrees/
+test-results/

3  autogpt_platform/.gitignore  vendored
@@ -1,3 +1,6 @@
*.ignore.*
*.ign.*
.application.logs
+
+# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
+.claude/settings.local.json

@@ -59,6 +59,8 @@ class OAuthState(BaseModel):
    code_verifier: Optional[str] = None
    scopes: list[str]
    """Unix timestamp (seconds) indicating when this OAuth state expires"""
+    credential_id: Optional[str] = None
+    """If set, this OAuth flow upgrades an existing credential's scopes."""


class UserMetadata(BaseModel):

@@ -179,6 +179,9 @@ MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=

+# Platform Bot Linking
+PLATFORM_LINK_BASE_URL=http://localhost:3000/link
+
# Communication Services
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=

@@ -0,0 +1,932 @@
import asyncio
import logging
from typing import List

from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.models import User as AuthUser
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import AgentExecutionStatus
from pydantic import BaseModel

from backend.api.features.admin.model import (
    AgentDiagnosticsResponse,
    ExecutionDiagnosticsResponse,
)
from backend.data.diagnostics import (
    FailedExecutionDetail,
    OrphanedScheduleDetail,
    RunningExecutionDetail,
    ScheduleDetail,
    ScheduleHealthMetrics,
    cleanup_all_stuck_queued_executions,
    cleanup_orphaned_executions_bulk,
    cleanup_orphaned_schedules_bulk,
    get_agent_diagnostics,
    get_all_orphaned_execution_ids,
    get_all_schedules_details,
    get_all_stuck_queued_execution_ids,
    get_execution_diagnostics,
    get_failed_executions_count,
    get_failed_executions_details,
    get_invalid_executions_details,
    get_long_running_executions_details,
    get_orphaned_executions_details,
    get_orphaned_schedules_details,
    get_running_executions_details,
    get_schedule_health_metrics,
    get_stuck_queued_executions_details,
    stop_all_long_running_executions,
)
from backend.data.execution import get_graph_executions
from backend.executor.utils import add_graph_execution, stop_graph_execution

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/admin",
    tags=["diagnostics", "admin"],
    dependencies=[Security(requires_admin_user)],
)


class RunningExecutionsListResponse(BaseModel):
    """Response model for list of running executions"""

    executions: List[RunningExecutionDetail]
    total: int


class FailedExecutionsListResponse(BaseModel):
    """Response model for list of failed executions"""

    executions: List[FailedExecutionDetail]
    total: int


class StopExecutionRequest(BaseModel):
    """Request model for stopping a single execution"""

    execution_id: str


class StopExecutionsRequest(BaseModel):
    """Request model for stopping multiple executions"""

    execution_ids: List[str]


class StopExecutionResponse(BaseModel):
    """Response model for stop execution operations"""

    success: bool
    stopped_count: int = 0
    message: str


class RequeueExecutionResponse(BaseModel):
    """Response model for requeue execution operations"""

    success: bool
    requeued_count: int = 0
    message: str


@router.get(
    "/diagnostics/executions",
    response_model=ExecutionDiagnosticsResponse,
    summary="Get Execution Diagnostics",
)
async def get_execution_diagnostics_endpoint():
    """
    Get comprehensive diagnostic information about execution status.

    Returns all execution metrics including:
    - Current state (running, queued)
    - Orphaned executions (>24h old, likely not in executor)
    - Failure metrics (1h, 24h, rate)
    - Long-running detection (stuck >1h, >24h)
    - Stuck queued detection
    - Throughput metrics (completions/hour)
    - RabbitMQ queue depths
    """
    logger.info("Getting execution diagnostics")

    diagnostics = await get_execution_diagnostics()

    response = ExecutionDiagnosticsResponse(
        running_executions=diagnostics.running_count,
        queued_executions_db=diagnostics.queued_db_count,
        queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth,
        cancel_queue_depth=diagnostics.cancel_queue_depth,
        orphaned_running=diagnostics.orphaned_running,
        orphaned_queued=diagnostics.orphaned_queued,
        failed_count_1h=diagnostics.failed_count_1h,
        failed_count_24h=diagnostics.failed_count_24h,
        failure_rate_24h=diagnostics.failure_rate_24h,
        stuck_running_24h=diagnostics.stuck_running_24h,
        stuck_running_1h=diagnostics.stuck_running_1h,
        oldest_running_hours=diagnostics.oldest_running_hours,
        stuck_queued_1h=diagnostics.stuck_queued_1h,
        queued_never_started=diagnostics.queued_never_started,
        invalid_queued_with_start=diagnostics.invalid_queued_with_start,
        invalid_running_without_start=diagnostics.invalid_running_without_start,
        completed_1h=diagnostics.completed_1h,
        completed_24h=diagnostics.completed_24h,
        throughput_per_hour=diagnostics.throughput_per_hour,
        timestamp=diagnostics.timestamp,
    )

    logger.info(
        f"Execution diagnostics: running={diagnostics.running_count}, "
        f"queued_db={diagnostics.queued_db_count}, "
        f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, "
        f"failed_24h={diagnostics.failed_count_24h}"
    )

    return response


@router.get(
    "/diagnostics/agents",
    response_model=AgentDiagnosticsResponse,
    summary="Get Agent Diagnostics",
)
async def get_agent_diagnostics_endpoint():
    """
    Get diagnostic information about agents.

    Returns:
        - agents_with_active_executions: Number of unique agents with running/queued executions
        - timestamp: Current timestamp
    """
    logger.info("Getting agent diagnostics")

    diagnostics = await get_agent_diagnostics()

    response = AgentDiagnosticsResponse(
        agents_with_active_executions=diagnostics.agents_with_active_executions,
        timestamp=diagnostics.timestamp,
    )

    logger.info(
        f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}"
    )

    return response


@router.get(
    "/diagnostics/executions/running",
    response_model=RunningExecutionsListResponse,
    summary="List Running Executions",
)
async def list_running_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of running and queued executions (recent, likely active).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of running executions with details
    """
    logger.info(f"Listing running executions (limit={limit}, offset={offset})")

    executions = await get_running_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.running_count + diagnostics.queued_db_count

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/orphaned",
    response_model=RunningExecutionsListResponse,
    summary="List Orphaned Executions",
)
async def list_orphaned_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of orphaned executions (>24h old, likely not in executor).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of orphaned executions with details
    """
    logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})")

    executions = await get_orphaned_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.orphaned_running + diagnostics.orphaned_queued

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/failed",
    response_model=FailedExecutionsListResponse,
    summary="List Failed Executions",
)
async def list_failed_executions(
    limit: int = 100,
    offset: int = 0,
    hours: int = 24,
):
    """
    Get detailed list of failed executions.

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)
        hours: Number of hours to look back (default 24)

    Returns:
        List of failed executions with error details
    """
    logger.info(
        f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})"
    )

    executions = await get_failed_executions_details(
        limit=limit, offset=offset, hours=hours
    )

    # Get total count for pagination
    # Always count actual total for given hours parameter
    total = await get_failed_executions_count(hours=hours)

    return FailedExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/long-running",
    response_model=RunningExecutionsListResponse,
    summary="List Long-Running Executions",
)
async def list_long_running_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of long-running executions (RUNNING status >24h).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of long-running executions with details
    """
    logger.info(f"Listing long-running executions (limit={limit}, offset={offset})")

    executions = await get_long_running_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.stuck_running_24h

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/stuck-queued",
    response_model=RunningExecutionsListResponse,
    summary="List Stuck Queued Executions",
)
async def list_stuck_queued_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of stuck queued executions (QUEUED >1h, never started).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of stuck queued executions with details
    """
    logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})")

    executions = await get_stuck_queued_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.stuck_queued_1h

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/invalid",
    response_model=RunningExecutionsListResponse,
    summary="List Invalid Executions",
)
async def list_invalid_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of executions in invalid states (READ-ONLY).

    Invalid states indicate data corruption and require manual investigation:
    - QUEUED but has startedAt (impossible - can't start while queued)
    - RUNNING but no startedAt (impossible - can't run without starting)

    ⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation.

    Each invalid execution likely has a different root cause (crashes, race conditions,
    DB corruption). Investigate the execution history and logs to determine appropriate
    action (manual cleanup, status fix, or leave as-is if system recovered).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of invalid state executions with details
    """
    logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})")

    executions = await get_invalid_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = (
        diagnostics.invalid_queued_with_start
        + diagnostics.invalid_running_without_start
    )

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.post(
    "/diagnostics/executions/requeue",
    response_model=RequeueExecutionResponse,
    summary="Requeue Stuck Execution",
)
async def requeue_single_execution(
    request: StopExecutionRequest,  # Reuse same request model (has execution_id)
    user: AuthUser = Security(requires_admin_user),
):
    """
    Requeue a stuck QUEUED execution (admin only).

    Uses add_graph_execution with existing graph_exec_id to requeue.

    ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.

    Args:
        request: Contains execution_id to requeue

    Returns:
        Success status and message
    """
    logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}")

    # Get the execution (validation - must be QUEUED)
    executions = await get_graph_executions(
        graph_exec_id=request.execution_id,
        statuses=[AgentExecutionStatus.QUEUED],
    )

    if not executions:
        raise HTTPException(
            status_code=404,
            detail="Execution not found or not in QUEUED status",
        )

    execution = executions[0]

    # Use add_graph_execution in requeue mode
    await add_graph_execution(
        graph_id=execution.graph_id,
        user_id=execution.user_id,
        graph_version=execution.graph_version,
        graph_exec_id=request.execution_id,  # Requeue existing execution
    )

    return RequeueExecutionResponse(
        success=True,
        requeued_count=1,
        message="Execution requeued successfully",
    )


@router.post(
    "/diagnostics/executions/requeue-bulk",
    response_model=RequeueExecutionResponse,
    summary="Requeue Multiple Stuck Executions",
)
async def requeue_multiple_executions(
    request: StopExecutionsRequest,  # Reuse same request model (has execution_ids)
    user: AuthUser = Security(requires_admin_user),
):
    """
    Requeue multiple stuck QUEUED executions (admin only).

    Uses add_graph_execution with existing graph_exec_id to requeue.

    ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.

    Args:
        request: Contains list of execution_ids to requeue

    Returns:
        Number of executions requeued and success message
    """
    logger.info(
        f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions"
    )

    # Get executions by ID list (must be QUEUED)
    executions = await get_graph_executions(
        execution_ids=request.execution_ids,
        statuses=[AgentExecutionStatus.QUEUED],
    )

    if not executions:
        return RequeueExecutionResponse(
            success=False,
            requeued_count=0,
            message="No QUEUED executions found to requeue",
        )

    # Requeue all executions in parallel using add_graph_execution
    async def requeue_one(exec) -> bool:
        try:
            await add_graph_execution(
                graph_id=exec.graph_id,
                user_id=exec.user_id,
                graph_version=exec.graph_version,
                graph_exec_id=exec.id,  # Requeue existing
            )
            return True
        except Exception as e:
            logger.error(f"Failed to requeue {exec.id}: {e}")
            return False

    results = await asyncio.gather(
        *[requeue_one(exec) for exec in executions], return_exceptions=False
    )

    requeued_count = sum(1 for success in results if success)

    return RequeueExecutionResponse(
        success=requeued_count > 0,
        requeued_count=requeued_count,
        message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions",
    )


@router.post(
    "/diagnostics/executions/stop",
    response_model=StopExecutionResponse,
    summary="Stop Single Execution",
)
async def stop_single_execution(
    request: StopExecutionRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Stop a single execution (admin only).

    Uses robust stop_graph_execution which cascades to children and waits for termination.

    Args:
        request: Contains execution_id to stop

    Returns:
        Success status and message
    """
    logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}")

    # Get the execution to find its owner user_id (required by stop_graph_execution)
    executions = await get_graph_executions(
        graph_exec_id=request.execution_id,
    )

    if not executions:
        raise HTTPException(status_code=404, detail="Execution not found")

    execution = executions[0]

    # Use robust stop_graph_execution (cascades to children, waits for termination)
    await stop_graph_execution(
        user_id=execution.user_id,
        graph_exec_id=request.execution_id,
        wait_timeout=15.0,
        cascade=True,
    )

    return StopExecutionResponse(
        success=True,
        stopped_count=1,
        message="Execution stopped successfully",
    )


@router.post(
    "/diagnostics/executions/stop-bulk",
    response_model=StopExecutionResponse,
    summary="Stop Multiple Executions",
)
async def stop_multiple_executions(
    request: StopExecutionsRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Stop multiple active executions (admin only).

    Uses robust stop_graph_execution which cascades to children and waits for termination.

    Args:
        request: Contains list of execution_ids to stop

    Returns:
        Number of executions stopped and success message
    """

    logger.info(
        f"Admin {user.user_id} stopping {len(request.execution_ids)} executions"
    )

    # Get executions by ID list
    executions = await get_graph_executions(
        execution_ids=request.execution_ids,
    )

    if not executions:
        return StopExecutionResponse(
            success=False,
            stopped_count=0,
            message="No executions found",
        )

    # Stop all executions in parallel using robust stop_graph_execution
    async def stop_one(exec) -> bool:
        try:
            await stop_graph_execution(
                user_id=exec.user_id,
                graph_exec_id=exec.id,
                wait_timeout=15.0,
                cascade=True,
            )
            return True
        except Exception as e:
            logger.error(f"Failed to stop execution {exec.id}: {e}")
            return False

    results = await asyncio.gather(
        *[stop_one(exec) for exec in executions], return_exceptions=False
    )

    stopped_count = sum(1 for success in results if success)

    return StopExecutionResponse(
        success=stopped_count > 0,
        stopped_count=stopped_count,
        message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions",
    )


@router.post(
    "/diagnostics/executions/cleanup-orphaned",
    response_model=StopExecutionResponse,
    summary="Cleanup Orphaned Executions",
)
async def cleanup_orphaned_executions(
    request: StopExecutionsRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Cleanup orphaned executions by directly updating DB status (admin only).
    For executions in DB but not actually running in executor (old/stale records).

    Args:
        request: Contains list of execution_ids to cleanup

    Returns:
        Number of executions cleaned up and success message
    """
    logger.info(
        f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions"
    )

    cleaned_count = await cleanup_orphaned_executions_bulk(
        request.execution_ids, user.user_id
    )

    return StopExecutionResponse(
        success=cleaned_count > 0,
        stopped_count=cleaned_count,
        message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions",
    )


# ============================================================================
# SCHEDULE DIAGNOSTICS ENDPOINTS
# ============================================================================


class SchedulesListResponse(BaseModel):
    """Response model for list of schedules"""

    schedules: List[ScheduleDetail]
    total: int


class OrphanedSchedulesListResponse(BaseModel):
    """Response model for list of orphaned schedules"""

    schedules: List[OrphanedScheduleDetail]
    total: int


class ScheduleCleanupRequest(BaseModel):
    """Request model for cleaning up schedules"""

    schedule_ids: List[str]


class ScheduleCleanupResponse(BaseModel):
    """Response model for schedule cleanup operations"""

    success: bool
    deleted_count: int = 0
    message: str


@router.get(
    "/diagnostics/schedules",
    response_model=ScheduleHealthMetrics,
    summary="Get Schedule Diagnostics",
)
async def get_schedule_diagnostics_endpoint():
    """
    Get comprehensive diagnostic information about schedule health.

    Returns schedule metrics including:
    - Total schedules (user vs system)
    - Orphaned schedules by category
    - Upcoming executions
    """
    logger.info("Getting schedule diagnostics")

    diagnostics = await get_schedule_health_metrics()

    logger.info(
        f"Schedule diagnostics: total={diagnostics.total_schedules}, "
        f"user={diagnostics.user_schedules}, "
        f"orphaned={diagnostics.total_orphaned}"
    )

    return diagnostics


@router.get(
    "/diagnostics/schedules/all",
    response_model=SchedulesListResponse,
    summary="List All User Schedules",
)
async def list_all_schedules(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of all user schedules (excludes system monitoring jobs).

    Args:
        limit: Maximum number of schedules to return (default 100)
        offset: Number of schedules to skip (default 0)

    Returns:
        List of schedules with details
    """
    logger.info(f"Listing all schedules (limit={limit}, offset={offset})")

    schedules = await get_all_schedules_details(limit=limit, offset=offset)

    # Get total count
    diagnostics = await get_schedule_health_metrics()
    total = diagnostics.user_schedules

    return SchedulesListResponse(schedules=schedules, total=total)


@router.get(
    "/diagnostics/schedules/orphaned",
    response_model=OrphanedSchedulesListResponse,
    summary="List Orphaned Schedules",
)
async def list_orphaned_schedules():
    """
    Get detailed list of orphaned schedules with orphan reasons.

    Returns:
        List of orphaned schedules categorized by orphan type
    """
    logger.info("Listing orphaned schedules")

    schedules = await get_orphaned_schedules_details()

    return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules))


@router.post(
    "/diagnostics/schedules/cleanup-orphaned",
    response_model=ScheduleCleanupResponse,
    summary="Cleanup Orphaned Schedules",
)
async def cleanup_orphaned_schedules(
    request: ScheduleCleanupRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Cleanup orphaned schedules by deleting from scheduler (admin only).

    Args:
        request: Contains list of schedule_ids to delete

    Returns:
        Number of schedules deleted and success message
    """
    logger.info(
        f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules"
    )

    deleted_count = await cleanup_orphaned_schedules_bulk(
        request.schedule_ids, user.user_id
    )

    return ScheduleCleanupResponse(
        success=deleted_count > 0,
        deleted_count=deleted_count,
        message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules",
    )


@router.post(
    "/diagnostics/executions/stop-all-long-running",
    response_model=StopExecutionResponse,
    summary="Stop ALL Long-Running Executions",
)
async def stop_all_long_running_executions_endpoint(
    user: AuthUser = Security(requires_admin_user),
):
    """
    Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only).
    Operates on entire dataset, not limited to pagination.

    Returns:
        Number of executions stopped and success message
    """
    logger.info(f"Admin {user.user_id} stopping ALL long-running executions")

    stopped_count = await stop_all_long_running_executions(user.user_id)

    return StopExecutionResponse(
        success=stopped_count > 0,
        stopped_count=stopped_count,
        message=f"Stopped {stopped_count} long-running executions",
    )


@router.post(
    "/diagnostics/executions/cleanup-all-orphaned",
    response_model=StopExecutionResponse,
    summary="Cleanup ALL Orphaned Executions",
)
async def cleanup_all_orphaned_executions(
    user: AuthUser = Security(requires_admin_user),
):
    """
    Cleanup ALL orphaned executions (>24h old) by directly updating DB status.
    Operates on all executions, not just paginated results.

    Returns:
        Number of executions cleaned up and success message
    """
    logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions")

    # Fetch all orphaned execution IDs
    execution_ids = await get_all_orphaned_execution_ids()

    if not execution_ids:
        return StopExecutionResponse(
            success=True,
            stopped_count=0,
            message="No orphaned executions to cleanup",
        )

    cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id)

    return StopExecutionResponse(
        success=cleaned_count > 0,
        stopped_count=cleaned_count,
        message=f"Cleaned up {cleaned_count} orphaned executions",
    )


@router.post(
    "/diagnostics/executions/cleanup-all-stuck-queued",
    response_model=StopExecutionResponse,
    summary="Cleanup ALL Stuck Queued Executions",
)
async def cleanup_all_stuck_queued_executions_endpoint(
    user: AuthUser = Security(requires_admin_user),
):
    """
    Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only).
    Operates on entire dataset, not limited to pagination.

    Returns:
        Number of executions cleaned up and success message
    """
    logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions")

    cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id)

    return StopExecutionResponse(
        success=cleaned_count > 0,
        stopped_count=cleaned_count,
        message=f"Cleaned up {cleaned_count} stuck queued executions",
    )


@router.post(
    "/diagnostics/executions/requeue-all-stuck",
    response_model=RequeueExecutionResponse,
    summary="Requeue ALL Stuck Queued Executions",
)
async def requeue_all_stuck_executions(
    user: AuthUser = Security(requires_admin_user),
):
    """
    Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ.
    Operates on all executions, not just paginated results.

    Uses add_graph_execution with existing graph_exec_id to requeue.

    ⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits.

    Returns:
        Number of executions requeued and success message
    """
    logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions")

    # Fetch all stuck queued execution IDs
    execution_ids = await get_all_stuck_queued_execution_ids()

    if not execution_ids:
        return RequeueExecutionResponse(
            success=True,
            requeued_count=0,
            message="No stuck queued executions to requeue",
        )

    # Get stuck executions by ID list (must be QUEUED)
    executions = await get_graph_executions(
        execution_ids=execution_ids,
        statuses=[AgentExecutionStatus.QUEUED],
    )

    # Requeue all in parallel using add_graph_execution
    async def requeue_one(exec) -> bool:
        try:
            await add_graph_execution(
                graph_id=exec.graph_id,
                user_id=exec.user_id,
                graph_version=exec.graph_version,
                graph_exec_id=exec.id,  # Requeue existing
            )
            return True
        except Exception as e:
            logger.error(f"Failed to requeue {exec.id}: {e}")
            return False

    results = await asyncio.gather(
        *[requeue_one(exec) for exec in executions], return_exceptions=False
    )

    requeued_count = sum(1 for success in results if success)

    return RequeueExecutionResponse(
        success=requeued_count > 0,
        requeued_count=requeued_count,
        message=f"Requeued {requeued_count} stuck executions",
    )
@@ -0,0 +1,889 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma.enums import AgentExecutionStatus

import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes
from backend.data.diagnostics import (
    AgentDiagnosticsSummary,
    ExecutionDiagnosticsSummary,
    FailedExecutionDetail,
    OrphanedScheduleDetail,
    RunningExecutionDetail,
    ScheduleDetail,
    ScheduleHealthMetrics,
)
from backend.data.execution import GraphExecutionMeta

app = fastapi.FastAPI()
app.include_router(diagnostics_admin_routes.router)

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Setup admin auth overrides for all tests in this module"""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_get_execution_diagnostics_success(
    mocker: pytest_mock.MockFixture,
):
    """Test fetching execution diagnostics with invalid state detection"""
    mock_diagnostics = ExecutionDiagnosticsSummary(
        running_count=10,
        queued_db_count=5,
        rabbitmq_queue_depth=3,
        cancel_queue_depth=0,
        orphaned_running=2,
        orphaned_queued=1,
        failed_count_1h=5,
        failed_count_24h=20,
        failure_rate_24h=0.83,
        stuck_running_24h=1,
        stuck_running_1h=3,
        oldest_running_hours=26.5,
        stuck_queued_1h=2,
        queued_never_started=1,
        invalid_queued_with_start=1,  # New invalid state
        invalid_running_without_start=1,  # New invalid state
        completed_1h=50,
        completed_24h=1200,
        throughput_per_hour=50.0,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=mock_diagnostics,
    )

    response = client.get("/admin/diagnostics/executions")

    assert response.status_code == 200
    data = response.json()

    # Verify new invalid state fields are included
    assert data["invalid_queued_with_start"] == 1
    assert data["invalid_running_without_start"] == 1
    # Verify all expected fields present
    assert "running_executions" in data
    assert "orphaned_running" in data
    assert "failed_count_24h" in data


def test_list_invalid_executions(
    mocker: pytest_mock.MockFixture,
):
    """Test listing executions in invalid states (read-only endpoint)"""
    mock_invalid_executions = [
        RunningExecutionDetail(
            execution_id="exec-invalid-1",
            graph_id="graph-123",
            graph_name="Test Graph",
            graph_version=1,
            user_id="user-123",
            user_email="test@example.com",
            status="QUEUED",
            created_at=datetime.now(timezone.utc),
            started_at=datetime.now(
                timezone.utc
            ),  # QUEUED but has startedAt - INVALID!
            queue_status=None,
        ),
        RunningExecutionDetail(
            execution_id="exec-invalid-2",
            graph_id="graph-456",
            graph_name="Another Graph",
            graph_version=2,
            user_id="user-456",
            user_email="user@example.com",
            status="RUNNING",
            created_at=datetime.now(timezone.utc),
            started_at=None,  # RUNNING but no startedAt - INVALID!
            queue_status=None,
        ),
    ]

    mock_diagnostics = ExecutionDiagnosticsSummary(
        running_count=10,
        queued_db_count=5,
        rabbitmq_queue_depth=3,
        cancel_queue_depth=0,
        orphaned_running=0,
        orphaned_queued=0,
        failed_count_1h=0,
        failed_count_24h=0,
        failure_rate_24h=0.0,
        stuck_running_24h=0,
        stuck_running_1h=0,
        oldest_running_hours=None,
        stuck_queued_1h=0,
        queued_never_started=0,
        invalid_queued_with_start=1,
        invalid_running_without_start=1,
        completed_1h=0,
        completed_24h=0,
        throughput_per_hour=0.0,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details",
        return_value=mock_invalid_executions,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=mock_diagnostics,
    )

    response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 2  # Sum of both invalid state types
    assert len(data["executions"]) == 2
    # Verify both types of invalid states are returned
    assert data["executions"][0]["execution_id"] in [
        "exec-invalid-1",
        "exec-invalid-2",
    ]
    assert data["executions"][1]["execution_id"] in [
        "exec-invalid-1",
        "exec-invalid-2",
    ]


def test_requeue_single_execution_with_add_graph_execution(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
admin_user_id: str,
|
||||
):
|
||||
"""Test requeueing uses add_graph_execution in requeue mode"""
|
||||
mock_exec_meta = GraphExecutionMeta(
|
||||
id="exec-stuck-123",
|
||||
user_id="user-123",
|
||||
graph_id="graph-456",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.QUEUED,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[mock_exec_meta],
|
||||
)
|
||||
|
||||
mock_add_graph_execution = mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue",
|
||||
json={"execution_id": "exec-stuck-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["requeued_count"] == 1
|
||||
|
||||
# Verify it used add_graph_execution in requeue mode
|
||||
mock_add_graph_execution.assert_called_once()
|
||||
call_kwargs = mock_add_graph_execution.call_args.kwargs
|
||||
assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode!
|
||||
assert call_kwargs["graph_id"] == "graph-456"
|
||||
assert call_kwargs["user_id"] == "user-123"
|
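
    # Per the assertions above, requeue mode is signaled by passing the stuck
    # execution's existing graph_exec_id to add_graph_execution instead of
    # letting it mint a brand-new execution for the graph.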


def test_stop_single_execution_with_stop_graph_execution(
    mocker: pytest_mock.MockFixture,
    admin_user_id: str,
):
    """Test stopping uses robust stop_graph_execution"""
    mock_exec_meta = GraphExecutionMeta(
        id="exec-running-123",
        user_id="user-789",
        graph_id="graph-999",
        graph_version=2,
        inputs=None,
        credential_inputs=None,
        nodes_input_masks=None,
        preset_id=None,
        status=AgentExecutionStatus.RUNNING,
        started_at=datetime.now(timezone.utc),
        ended_at=datetime.now(timezone.utc),
        stats=None,
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[mock_exec_meta],
    )

    mock_stop_graph_execution = mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post(
        "/admin/diagnostics/executions/stop",
        json={"execution_id": "exec-running-123"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 1

    # Verify it used stop_graph_execution with cascade
    mock_stop_graph_execution.assert_called_once()
    call_kwargs = mock_stop_graph_execution.call_args.kwargs
    assert call_kwargs["graph_exec_id"] == "exec-running-123"
    assert call_kwargs["user_id"] == "user-789"
    assert call_kwargs["cascade"] is True  # Stops children too!
    assert call_kwargs["wait_timeout"] == 15.0


def test_requeue_not_queued_execution_fails(
    mocker: pytest_mock.MockFixture,
):
    """Test that requeue fails if execution is not in QUEUED status"""
    # Mock an execution that's RUNNING (not QUEUED)
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[],  # No QUEUED executions found
    )

    response = client.post(
        "/admin/diagnostics/executions/requeue",
        json={"execution_id": "exec-running-123"},
    )

    assert response.status_code == 404
    assert "not found or not in QUEUED status" in response.json()["detail"]


def test_list_invalid_executions_no_bulk_actions(
    mocker: pytest_mock.MockFixture,
):
    """Verify invalid executions endpoint is read-only (no bulk actions)"""
    # This is a documentation test - the endpoint exists but should not
    # have corresponding cleanup/stop/requeue endpoints

    # These endpoints should NOT exist for invalid states:
    invalid_bulk_endpoints = [
        "/admin/diagnostics/executions/cleanup-invalid",
        "/admin/diagnostics/executions/stop-invalid",
        "/admin/diagnostics/executions/requeue-invalid",
    ]

    for endpoint in invalid_bulk_endpoints:
        response = client.post(endpoint, json={"execution_ids": ["test"]})
        assert response.status_code == 404, f"{endpoint} should not exist (read-only)"


def test_execution_ids_filter_efficiency(
    mocker: pytest_mock.MockFixture,
):
    """Test that bulk operations use efficient execution_ids filter"""
    mock_exec_metas = [
        GraphExecutionMeta(
            id=f"exec-{i}",
            user_id=f"user-{i}",
            graph_id="graph-123",
            graph_version=1,
            inputs=None,
            credential_inputs=None,
            nodes_input_masks=None,
            preset_id=None,
            status=AgentExecutionStatus.QUEUED,
            started_at=datetime.now(timezone.utc),
            ended_at=datetime.now(timezone.utc),
            stats=None,
        )
        for i in range(3)
    ]

    mock_get_graph_executions = mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=mock_exec_metas,
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post(
        "/admin/diagnostics/executions/requeue-bulk",
        json={"execution_ids": ["exec-0", "exec-1", "exec-2"]},
    )

    assert response.status_code == 200

    # Verify it used execution_ids filter (not fetching all queued)
    mock_get_graph_executions.assert_called_once()
    call_kwargs = mock_get_graph_executions.call_args.kwargs
    assert "execution_ids" in call_kwargs
    assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"]
    assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED]


# ---------------------------------------------------------------------------
# Helper: reusable mock diagnostics summary
# ---------------------------------------------------------------------------


def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary:
    defaults = dict(
        running_count=10,
        queued_db_count=5,
        rabbitmq_queue_depth=3,
        cancel_queue_depth=0,
        orphaned_running=2,
        orphaned_queued=1,
        failed_count_1h=5,
        failed_count_24h=20,
        failure_rate_24h=0.83,
        stuck_running_24h=3,
        stuck_running_1h=5,
        oldest_running_hours=26.5,
        stuck_queued_1h=2,
        queued_never_started=1,
        invalid_queued_with_start=1,
        invalid_running_without_start=1,
        completed_1h=50,
        completed_24h=1200,
        throughput_per_hour=50.0,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    defaults.update(overrides)
    return ExecutionDiagnosticsSummary(**defaults)
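
# Usage sketch for the helper above (values hypothetical):
#   _make_mock_diagnostics(orphaned_running=0) returns the default summary
#   with only that field overridden; every other default stays intact.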


_SENTINEL = object()
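# Module-level sentinel: lets callers distinguish "argument omitted" (default
# to a real timestamp below) from an explicit started_at=None, which a plain
# None default could not express.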


def _make_mock_execution(
    exec_id: str = "exec-1",
    status: str = "RUNNING",
    started_at: datetime | None | object = _SENTINEL,
) -> RunningExecutionDetail:
    return RunningExecutionDetail(
        execution_id=exec_id,
        graph_id="graph-123",
        graph_name="Test Graph",
        graph_version=1,
        user_id="user-123",
        user_email="test@example.com",
        status=status,
        created_at=datetime.now(timezone.utc),
        started_at=(
            datetime.now(timezone.utc) if started_at is _SENTINEL else started_at
        ),
        queue_status=None,
    )
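
# Illustrative calls (values hypothetical):
#   _make_mock_execution("exec-1")                   -> started_at is now()
#   _make_mock_execution("exec-1", started_at=None)  -> never started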


def _make_mock_failed_execution(
    exec_id: str = "exec-fail-1",
) -> FailedExecutionDetail:
    return FailedExecutionDetail(
        execution_id=exec_id,
        graph_id="graph-123",
        graph_name="Test Graph",
        graph_version=1,
        user_id="user-123",
        user_email="test@example.com",
        status="FAILED",
        created_at=datetime.now(timezone.utc),
        started_at=datetime.now(timezone.utc),
        failed_at=datetime.now(timezone.utc),
        error_message="Something went wrong",
    )


def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics:
    defaults = dict(
        total_schedules=15,
        user_schedules=10,
        system_schedules=5,
        orphaned_deleted_graph=2,
        orphaned_no_library_access=1,
        orphaned_invalid_credentials=0,
        orphaned_validation_failed=0,
        total_orphaned=3,
        schedules_next_hour=4,
        schedules_next_24h=8,
        total_runs_next_hour=12,
        total_runs_next_24h=48,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    defaults.update(overrides)
    return ScheduleHealthMetrics(**defaults)


# ---------------------------------------------------------------------------
# GET endpoints: execution list variants
# ---------------------------------------------------------------------------


def test_list_running_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [
        _make_mock_execution("exec-run-1"),
        _make_mock_execution("exec-run-2"),
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 15  # running_count(10) + queued_db_count(5)
    assert len(data["executions"]) == 2
    assert data["executions"][0]["execution_id"] == "exec-run-1"


def test_list_orphaned_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 3  # orphaned_running(2) + orphaned_queued(1)
    assert len(data["executions"]) == 1


def test_list_failed_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [_make_mock_failed_execution("exec-fail-1")]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count",
        return_value=42,
    )

    response = client.get(
        "/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24"
    )

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 42
    assert len(data["executions"]) == 1
    assert data["executions"][0]["error_message"] == "Something went wrong"


def test_list_long_running_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [_make_mock_execution("exec-long-1")]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get(
        "/admin/diagnostics/executions/long-running?limit=50&offset=0"
    )

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 3  # stuck_running_24h
    assert len(data["executions"]) == 1


def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [
        _make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None)
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get(
        "/admin/diagnostics/executions/stuck-queued?limit=50&offset=0"
    )

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 2  # stuck_queued_1h
    assert len(data["executions"]) == 1


# ---------------------------------------------------------------------------
# GET endpoints: agent + schedule diagnostics
# ---------------------------------------------------------------------------


def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture):
    mock_diag = AgentDiagnosticsSummary(
        agents_with_active_executions=7,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics",
        return_value=mock_diag,
    )

    response = client.get("/admin/diagnostics/agents")

    assert response.status_code == 200
    data = response.json()
    assert data["agents_with_active_executions"] == 7


def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture):
    mock_metrics = _make_mock_schedule_health()
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
        return_value=mock_metrics,
    )

    response = client.get("/admin/diagnostics/schedules")

    assert response.status_code == 200
    data = response.json()
    assert data["user_schedules"] == 10
    assert data["total_orphaned"] == 3
    assert data["total_runs_next_hour"] == 12


def test_list_all_schedules(mocker: pytest_mock.MockFixture):
    mock_schedules = [
        ScheduleDetail(
            schedule_id="sched-1",
            schedule_name="Daily Run",
            graph_id="graph-1",
            graph_name="My Agent",
            graph_version=1,
            user_id="user-1",
            user_email="alice@example.com",
            cron="0 9 * * *",
            timezone="UTC",
            next_run_time=datetime.now(timezone.utc).isoformat(),
        ),
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details",
        return_value=mock_schedules,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
        return_value=_make_mock_schedule_health(),
    )

    response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 10
    assert len(data["schedules"]) == 1
    assert data["schedules"][0]["schedule_name"] == "Daily Run"


def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture):
    mock_orphans = [
        OrphanedScheduleDetail(
            schedule_id="sched-orphan-1",
            schedule_name="Ghost Schedule",
            graph_id="graph-deleted",
            graph_version=1,
            user_id="user-1",
            orphan_reason="deleted_graph",
            error_detail=None,
            next_run_time=datetime.now(timezone.utc).isoformat(),
        ),
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details",
        return_value=mock_orphans,
    )

    response = client.get("/admin/diagnostics/schedules/orphaned")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 1
    assert data["schedules"][0]["orphan_reason"] == "deleted_graph"


# ---------------------------------------------------------------------------
# POST endpoints: bulk stop, cleanup, requeue
# ---------------------------------------------------------------------------


def test_stop_multiple_executions(mocker: pytest_mock.MockFixture):
    mock_exec_metas = [
        GraphExecutionMeta(
            id=f"exec-{i}",
            user_id=f"user-{i}",
            graph_id="graph-123",
            graph_version=1,
            inputs=None,
            credential_inputs=None,
            nodes_input_masks=None,
            preset_id=None,
            status=AgentExecutionStatus.RUNNING,
            started_at=datetime.now(timezone.utc),
            ended_at=None,
            stats=None,
        )
        for i in range(2)
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=mock_exec_metas,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post(
        "/admin/diagnostics/executions/stop-bulk",
        json={"execution_ids": ["exec-0", "exec-1"]},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 2


def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[],
    )

    response = client.post(
        "/admin/diagnostics/executions/stop-bulk",
        json={"execution_ids": ["nonexistent"]},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is False
    assert data["stopped_count"] == 0


def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
        return_value=3,
    )

    response = client.post(
        "/admin/diagnostics/executions/cleanup-orphaned",
        json={"execution_ids": ["exec-1", "exec-2", "exec-3"]},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 3


def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk",
        return_value=2,
    )

    response = client.post(
        "/admin/diagnostics/schedules/cleanup-orphaned",
        json={"schedule_ids": ["sched-1", "sched-2"]},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["deleted_count"] == 2


def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions",
        return_value=5,
    )

    response = client.post("/admin/diagnostics/executions/stop-all-long-running")

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 5


def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
        return_value=["exec-1", "exec-2"],
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
        return_value=2,
    )

    response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 2


def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
        return_value=[],
    )

    response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 0
    assert "No orphaned" in data["message"]


def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions",
        return_value=4,
    )

    response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued")

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 4


def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture):
    mock_exec_metas = [
        GraphExecutionMeta(
            id=f"exec-stuck-{i}",
            user_id=f"user-{i}",
            graph_id="graph-123",
            graph_version=1,
            inputs=None,
            credential_inputs=None,
            nodes_input_masks=None,
            preset_id=None,
            status=AgentExecutionStatus.QUEUED,
            started_at=None,
            ended_at=None,
            stats=None,
        )
        for i in range(3)
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
        return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"],
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=mock_exec_metas,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post("/admin/diagnostics/executions/requeue-all-stuck")

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["requeued_count"] == 3


def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
        return_value=[],
    )

    response = client.post("/admin/diagnostics/executions/requeue-all-stuck")

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["requeued_count"] == 0
    assert "No stuck" in data["message"]


def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[],
    )

    response = client.post(
        "/admin/diagnostics/executions/requeue-bulk",
        json={"execution_ids": ["nonexistent"]},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is False
    assert data["requeued_count"] == 0


def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture):
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[],
    )

    response = client.post(
        "/admin/diagnostics/executions/stop",
        json={"execution_id": "nonexistent"},
    )

    assert response.status_code == 404
    assert "not found" in response.json()["detail"]

@@ -14,3 +14,70 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
    new_balance: int
    transaction_key: str


class ExecutionDiagnosticsResponse(BaseModel):
    """Response model for execution diagnostics"""

    # Current execution state
    running_executions: int
    queued_executions_db: int
    queued_executions_rabbitmq: int
    cancel_queue_depth: int

    # Orphaned execution detection
    orphaned_running: int
    orphaned_queued: int

    # Failure metrics
    failed_count_1h: int
    failed_count_24h: int
    failure_rate_24h: float

    # Long-running detection
    stuck_running_24h: int
    stuck_running_1h: int
    oldest_running_hours: float | None

    # Stuck queued detection
    stuck_queued_1h: int
    queued_never_started: int

    # Invalid state detection (data corruption - no auto-actions)
    invalid_queued_with_start: int
    invalid_running_without_start: int

    # Throughput metrics
    completed_1h: int
    completed_24h: int
    throughput_per_hour: float

    timestamp: str


class AgentDiagnosticsResponse(BaseModel):
    """Response model for agent diagnostics"""

    agents_with_active_executions: int
    timestamp: str


class ScheduleHealthMetrics(BaseModel):
    """Response model for schedule diagnostics"""

    total_schedules: int
    user_schedules: int
    system_schedules: int

    # Orphan detection
    orphaned_deleted_graph: int
    orphaned_no_library_access: int
    orphaned_invalid_credentials: int
    orphaned_validation_failed: int
    total_orphaned: int

    # Upcoming
    schedules_next_hour: int
    schedules_next_24h: int
    total_runs_next_hour: int
    total_runs_next_24h: int

    timestamp: str

@@ -32,10 +32,10 @@ router = APIRouter(
class UserRateLimitResponse(BaseModel):
    user_id: str
    user_email: Optional[str] = None
-    daily_token_limit: int
-    weekly_token_limit: int
-    daily_tokens_used: int
-    weekly_tokens_used: int
+    daily_cost_limit_microdollars: int
+    weekly_cost_limit_microdollars: int
+    daily_cost_used_microdollars: int
+    weekly_cost_used_microdollars: int
    tier: SubscriptionTier

@@ -101,17 +101,19 @@ async def get_user_rate_limit(
    logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)

    daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        resolved_id, config.daily_token_limit, config.weekly_token_limit
+        resolved_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
    )
    usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)

    return UserRateLimitResponse(
        user_id=resolved_id,
        user_email=resolved_email,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
-        daily_tokens_used=usage.daily.used,
-        weekly_tokens_used=usage.weekly.used,
+        daily_cost_limit_microdollars=daily_limit,
+        weekly_cost_limit_microdollars=weekly_limit,
+        daily_cost_used_microdollars=usage.daily.used,
+        weekly_cost_used_microdollars=usage.weekly.used,
        tier=tier,
    )

@@ -141,7 +143,9 @@ async def reset_user_rate_limit(
        raise HTTPException(status_code=500, detail="Failed to reset usage") from e

    daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        user_id, config.daily_token_limit, config.weekly_token_limit
+        user_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
    )
    usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)

@@ -154,10 +158,10 @@ async def reset_user_rate_limit(
    return UserRateLimitResponse(
        user_id=user_id,
        user_email=resolved_email,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
-        daily_tokens_used=usage.daily.used,
-        weekly_tokens_used=usage.weekly.used,
+        daily_cost_limit_microdollars=daily_limit,
+        weekly_cost_limit_microdollars=weekly_limit,
+        daily_cost_used_microdollars=usage.daily.used,
+        weekly_cost_used_microdollars=usage.weekly.used,
        tier=tier,
    )

@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
    mocker.patch(
        f"{_MOCK_MODULE}.get_global_rate_limits",
        new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_usage_status",

@@ -85,11 +85,11 @@ def test_get_rate_limit(
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["user_email"] == _TARGET_EMAIL
-    assert data["daily_token_limit"] == 2_500_000
-    assert data["weekly_token_limit"] == 12_500_000
-    assert data["daily_tokens_used"] == 500_000
-    assert data["weekly_tokens_used"] == 3_000_000
-    assert data["tier"] == "FREE"
+    assert data["daily_cost_limit_microdollars"] == 2_500_000
+    assert data["weekly_cost_limit_microdollars"] == 12_500_000
+    assert data["daily_cost_used_microdollars"] == 500_000
+    assert data["weekly_cost_used_microdollars"] == 3_000_000
+    assert data["tier"] == "BASIC"

    configured_snapshot.assert_match(
        json.dumps(data, indent=2, sort_keys=True) + "\n",

@@ -117,7 +117,7 @@ def test_get_rate_limit_by_email(
    data = response.json()
    assert data["user_id"] == target_user_id
    assert data["user_email"] == _TARGET_EMAIL
-    assert data["daily_token_limit"] == 2_500_000
+    assert data["daily_cost_limit_microdollars"] == 2_500_000


def test_get_rate_limit_by_email_not_found(

@@ -160,10 +160,10 @@ def test_reset_user_usage_daily_only(

    assert response.status_code == 200
    data = response.json()
-    assert data["daily_tokens_used"] == 0
+    assert data["daily_cost_used_microdollars"] == 0
    # Weekly is untouched
-    assert data["weekly_tokens_used"] == 3_000_000
-    assert data["tier"] == "FREE"
+    assert data["weekly_cost_used_microdollars"] == 3_000_000
+    assert data["tier"] == "BASIC"

    mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)

@@ -192,9 +192,9 @@ def test_reset_user_usage_daily_and_weekly(

    assert response.status_code == 200
    data = response.json()
-    assert data["daily_tokens_used"] == 0
-    assert data["weekly_tokens_used"] == 0
-    assert data["tier"] == "FREE"
+    assert data["daily_cost_used_microdollars"] == 0
+    assert data["weekly_cost_used_microdollars"] == 0
+    assert data["tier"] == "BASIC"

    mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)

@@ -231,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
    mocker.patch(
        f"{_MOCK_MODULE}.get_global_rate_limits",
        new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
    )
    mocker.patch(
        f"{_MOCK_MODULE}.get_usage_status",

@@ -324,7 +324,7 @@ def test_set_user_tier(
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
-        return_value=SubscriptionTier.FREE,
+        return_value=SubscriptionTier.BASIC,
    )
    mock_set = mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",

@@ -347,7 +347,7 @@ def test_set_user_tier_downgrade(
    mocker: pytest_mock.MockerFixture,
    target_user_id: str,
) -> None:
-    """Test downgrading a user's tier from PRO to FREE."""
+    """Test downgrading a user's tier from PRO to BASIC."""
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_email_by_id",
        new_callable=AsyncMock,

@@ -365,14 +365,14 @@ def test_set_user_tier_downgrade(

    response = client.post(
        "/admin/rate_limit/tier",
-        json={"user_id": target_user_id, "tier": "FREE"},
+        json={"user_id": target_user_id, "tier": "BASIC"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["user_id"] == target_user_id
-    assert data["tier"] == "FREE"
-    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
+    assert data["tier"] == "BASIC"
+    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.BASIC)


def test_set_user_tier_invalid_tier(

@@ -456,7 +456,7 @@ def test_set_user_tier_db_failure(
    mocker.patch(
        f"{_MOCK_MODULE}.get_user_tier",
        new_callable=AsyncMock,
-        return_value=SubscriptionTier.FREE,
+        return_value=SubscriptionTier.BASIC,
    )
    mocker.patch(
        f"{_MOCK_MODULE}.set_user_tier",

@@ -13,6 +13,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator

from backend.copilot import service as chat_service
from backend.copilot import stream_registry
+from backend.copilot.builder_context import resolve_session_permissions
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn

@@ -24,6 +25,7 @@ from backend.copilot.model import (
    create_chat_session,
    delete_chat_session,
    get_chat_session,
+    get_or_create_builder_session,
    get_user_sessions,
    update_session_title,
)

@@ -34,7 +36,7 @@ from backend.copilot.pending_message_helpers import (
)
from backend.copilot.pending_messages import peek_pending_messages
from backend.copilot.rate_limit import (
-    CoPilotUsageStatus,
+    CoPilotUsagePublic,
    RateLimitExceeded,
    acquire_reset_lock,
    check_rate_limit,

@@ -73,6 +75,7 @@ from backend.copilot.tools.models import (
    NoResultsResponse,
    SetupRequirementsResponse,
    SuggestedGoalResponse,
+    TodoWriteResponse,
    UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message

@@ -133,7 +136,7 @@ def _strip_injected_context(message: dict) -> dict:
class StreamChatRequest(BaseModel):
    """Request model for streaming chat with optional context."""

-    message: str
+    message: str = Field(max_length=64_000)
    is_user_message: bool = True
    context: dict[str, str] | None = None  # {url: str, content: str}
    file_ids: list[str] | None = Field(

@@ -165,15 +168,31 @@ class PeekPendingMessagesResponse(BaseModel):


class CreateSessionRequest(BaseModel):
-    """Request model for creating a new chat session.
+    """Request model for creating (or get-or-creating) a chat session.

+    Two modes, selected by the body:
+
+    - Default: create a fresh session. ``dry_run`` is a **top-level**
+      field — do not nest it inside ``metadata``.
+    - Builder-bound: when ``builder_graph_id`` is set, the endpoint
+      switches to **get-or-create** keyed on
+      ``(user_id, builder_graph_id)``. The builder panel calls this on
+      mount so the chat persists across refreshes. Graph ownership is
+      validated inside :func:`get_or_create_builder_session`. Write-side
+      scope is enforced per-tool (``edit_agent`` / ``run_agent`` reject
+      any ``agent_id`` other than the bound graph) and a small blacklist
+      hides tools that conflict with the panel's scope
+      (``create_agent`` / ``customize_agent`` / ``get_agent_building_guide``
+      — see :data:`BUILDER_BLOCKED_TOOLS`). Read-side lookups
+      (``find_block``, ``find_agent``, ``search_docs``, …) stay open.
+
    ``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
    Extra/unknown fields are rejected (422) to prevent silent mis-use.
    """

    model_config = ConfigDict(extra="forbid")

    dry_run: bool = False
+    builder_graph_id: str | None = Field(default=None, max_length=128)


class CreateSessionResponse(BaseModel):
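
# Annotation on the CreateSessionRequest change above (request bodies are
# illustrative; "graph-123" is a hypothetical ID):
#   {"dry_run": true}                  -> fresh session, dry-run tool calls
#   {"builder_graph_id": "graph-123"}  -> get-or-create builder-bound session
#   {"metadata": {"dry_run": true}}    -> 422, extra fields are forbidden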

@@ -318,29 +337,43 @@ async def create_session(
    user_id: Annotated[str, Security(auth.get_user_id)],
    request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
-    """
-    Create a new chat session.
+    """Create (or get-or-create) a chat session.

-    Initiates a new chat session for the authenticated user.
+    Two modes, selected by the request body:
+
+    - Default: create a fresh session for the user. ``dry_run=True`` forces
+      run_block and run_agent calls to use dry-run simulation.
+    - Builder-bound: when ``builder_graph_id`` is set, get-or-create keyed
+      on ``(user_id, builder_graph_id)``. Returns the existing session for
+      that graph or creates one locked to it. Graph ownership is validated
+      inside :func:`get_or_create_builder_session`; raises 404 on
+      unauthorized access. Write-side scope is enforced per-tool
+      (``edit_agent`` / ``run_agent`` reject any ``agent_id`` other than
+      the bound graph) and a small blacklist hides tools that conflict
+      with the panel's scope (see :data:`BUILDER_BLOCKED_TOOLS`).

    Args:
        user_id: The authenticated user ID parsed from the JWT (required).
-        request: Optional request body. When provided, ``dry_run=True``
-            forces run_block and run_agent calls to use dry-run simulation.
+        request: Optional request body with ``dry_run`` and/or
+            ``builder_graph_id``.

    Returns:
-        CreateSessionResponse: Details of the created session.
+        CreateSessionResponse: Details of the resulting session.
    """
    dry_run = request.dry_run if request else False
+    builder_graph_id = request.builder_graph_id if request else None

    logger.info(
        f"Creating session with user_id: "
        f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
        f"{', dry_run=True' if dry_run else ''}"
+        f"{f', builder_graph_id={builder_graph_id}' if builder_graph_id else ''}"
    )

-    session = await create_chat_session(user_id, dry_run=dry_run)
+    if builder_graph_id:
+        session = await get_or_create_builder_session(user_id, builder_graph_id)
+    else:
+        session = await create_chat_session(user_id, dry_run=dry_run)

    return CreateSessionResponse(
        id=session.session_id,

@@ -536,23 +569,27 @@ async def get_session(
)
async def get_copilot_usage(
    user_id: Annotated[str, Security(auth.get_user_id)],
-) -> CoPilotUsageStatus:
+) -> CoPilotUsagePublic:
    """Get CoPilot usage status for the authenticated user.

-    Returns current token usage vs limits for daily and weekly windows.
-    Global defaults sourced from LaunchDarkly (falling back to config).
-    Includes the user's rate-limit tier.
+    Returns the percentage of the daily/weekly allowance used — not the
+    raw spend or cap — so clients cannot derive per-turn cost or platform
+    margins. Global defaults sourced from LaunchDarkly (falling back to
+    config). Includes the user's rate-limit tier.
    """
    daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        user_id, config.daily_token_limit, config.weekly_token_limit
+        user_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
    )
-    return await get_usage_status(
+    status = await get_usage_status(
        user_id=user_id,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
+        daily_cost_limit=daily_limit,
+        weekly_cost_limit=weekly_limit,
        rate_limit_reset_cost=config.rate_limit_reset_cost,
        tier=tier,
    )
+    return CoPilotUsagePublic.from_status(status)


class RateLimitResetResponse(BaseModel):
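
# Annotation on the usage hunks above: per the tests in this commit,
# CoPilotUsagePublic.from_status(status) is assumed to keep only relative
# usage, e.g. daily.percent_used = used / limit * 100 (500 / 10_000 -> 5.0;
# 2_000 / 50_000 -> 4.0), dropping the raw microdollar used/limit fields.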

@@ -561,7 +598,9 @@ class RateLimitResetResponse(BaseModel):
    success: bool
    credits_charged: int = Field(description="Credits charged (in cents)")
    remaining_balance: int = Field(description="Credit balance after charge (in cents)")
-    usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
+    usage: CoPilotUsagePublic = Field(
+        description="Updated usage status after reset (percentages only)"
+    )


@router.post(

@@ -585,7 +624,7 @@ async def reset_copilot_usage(
) -> RateLimitResetResponse:
    """Reset the daily CoPilot rate limit by spending credits.

-    Allows users who have hit their daily token limit to spend credits
+    Allows users who have hit their daily cost limit to spend credits
    to reset their daily usage counter and continue working.
    Returns 400 if the feature is disabled or the user is not over the limit.
    Returns 402 if the user has insufficient credits.

@@ -604,7 +643,9 @@ async def reset_copilot_usage(
        )

    daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        user_id, config.daily_token_limit, config.weekly_token_limit
+        user_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
    )

    if daily_limit <= 0:

@@ -641,8 +682,8 @@ async def reset_copilot_usage(
    # used for limit checks, not returned to the client.)
    usage_status = await get_usage_status(
        user_id=user_id,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
+        daily_cost_limit=daily_limit,
+        weekly_cost_limit=weekly_limit,
        tier=tier,
    )
    if daily_limit > 0 and usage_status.daily.used < daily_limit:

@@ -677,7 +718,7 @@ async def reset_copilot_usage(

    # Reset daily usage in Redis. If this fails, refund the credits
    # so the user is not charged for a service they did not receive.
-    if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
+    if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit):
        # Compensate: refund the charged credits.
        refunded = False
        try:

@@ -713,11 +754,11 @@ async def reset_copilot_usage(
    finally:
        await release_reset_lock(user_id)

-    # Return updated usage status.
+    # Return updated usage status (public schema — percentages only).
    updated_usage = await get_usage_status(
        user_id=user_id,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
+        daily_cost_limit=daily_limit,
+        weekly_cost_limit=weekly_limit,
        rate_limit_reset_cost=config.rate_limit_reset_cost,
        tier=tier,
    )

@@ -726,7 +767,7 @@ async def reset_copilot_usage(
        success=True,
        credits_charged=cost,
        remaining_balance=remaining,
-        usage=updated_usage,
+        usage=CoPilotUsagePublic.from_status(updated_usage),
    )

@@ -787,7 +828,7 @@ async def cancel_session_task(
            ),
        },
        404: {"description": "Session not found or access denied"},
-        429: {"description": "Token rate-limit or call-frequency cap exceeded"},
+        429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
    },
)
async def stream_chat_post(

@@ -830,7 +871,8 @@ async def stream_chat_post(
        f"user={user_id}, message_len={len(request.message)}",
        extra={"json_fields": log_meta},
    )
-    await _validate_and_get_session(session_id, user_id)
+    session = await _validate_and_get_session(session_id, user_id)
+    builder_permissions = resolve_session_permissions(session)

    # Self-defensive queue-fallback: if a turn is already running, don't race
    # it on the cluster lock — drop the message into the pending buffer and

@@ -861,18 +903,20 @@ async def stream_chat_post(
            },
        )

-    # Pre-turn rate limit check (token-based).
+    # Pre-turn rate limit check (cost-based, microdollars).
    # check_rate_limit short-circuits internally when both limits are 0.
    # Global defaults sourced from LaunchDarkly, falling back to config.
    if user_id:
        try:
            daily_limit, weekly_limit, _ = await get_global_rate_limits(
-                user_id, config.daily_token_limit, config.weekly_token_limit
+                user_id,
+                config.daily_cost_limit_microdollars,
+                config.weekly_cost_limit_microdollars,
            )
            await check_rate_limit(
                user_id=user_id,
-                daily_token_limit=daily_limit,
-                weekly_token_limit=weekly_limit,
+                daily_cost_limit=daily_limit,
+                weekly_cost_limit=weekly_limit,
            )
        except RateLimitExceeded as e:
            raise HTTPException(status_code=429, detail=str(e)) from e

@@ -943,6 +987,7 @@ async def stream_chat_post(
                file_ids=sanitized_file_ids,
                mode=request.mode,
                model=request.model,
+                permissions=builder_permissions,
                request_arrival_at=request_arrival_at,
            )
        else:

@@ -1375,6 +1420,7 @@ ToolResponseUnion = (
    | MemorySearchResponse
    | MemoryForgetCandidatesResponse
    | MemoryForgetConfirmResponse
+    | TodoWriteResponse
)

@@ -11,10 +11,20 @@ import pytest_mock
from backend.api.features.chat import routes as chat_routes
from backend.api.features.chat.routes import _strip_injected_context
from backend.copilot.rate_limit import SubscriptionTier
+from backend.util.exceptions import NotFoundError

app = fastapi.FastAPI()
app.include_router(chat_routes.router)


+@app.exception_handler(NotFoundError)
+async def _not_found_handler(
+    request: fastapi.Request, exc: NotFoundError
+) -> fastapi.responses.JSONResponse:
+    """Mirror the production NotFoundError → 404 mapping from the REST app."""
+    return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
+
+
client = fastapi.testclient.TestClient(app)

TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"

@@ -296,8 +306,8 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerF

    _mock_stream_internals(mocker)
    # Ensure the rate-limit branch is entered by setting a non-zero limit.
-    mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
-    mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
+    mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
+    mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),

@@ -318,8 +328,8 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(
    from backend.copilot.rate_limit import RateLimitExceeded

    _mock_stream_internals(mocker)
-    mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
-    mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
+    mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
+    mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
    resets_at = datetime.now(UTC) + timedelta(days=3)
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",

@@ -341,8 +351,8 @@ def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
    from backend.copilot.rate_limit import RateLimitExceeded

    _mock_stream_internals(mocker)
-    mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
-    mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
+    mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
+    mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        side_effect=RateLimitExceeded(

@@ -370,7 +380,7 @@ def _mock_usage(
    weekly_used: int = 2000,
    daily_limit: int = 10000,
    weekly_limit: int = 50000,
-    tier: "SubscriptionTier" = SubscriptionTier.FREE,
+    tier: "SubscriptionTier" = SubscriptionTier.BASIC,
) -> AsyncMock:
    """Mock get_usage_status and get_global_rate_limits for usage endpoint tests.

@@ -402,25 +412,35 @@ def test_usage_returns_daily_and_weekly(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
-    """GET /usage returns daily and weekly usage."""
+    """GET /usage returns percentages for daily and weekly windows only.
+
+    The raw used/limit microdollar values MUST NOT leak — clients should not
+    be able to derive per-turn cost or platform margins from the public API.
+    """
    mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)

-    mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
-    mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
+    mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
+    mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)

    response = client.get("/usage")

    assert response.status_code == 200
    data = response.json()
-    assert data["daily"]["used"] == 500
-    assert data["weekly"]["used"] == 2000
+    # 500 / 10000 = 5%, 2000 / 50000 = 4%
+    assert data["daily"]["percent_used"] == 5.0
+    assert data["weekly"]["percent_used"] == 4.0
+    # Raw spend/limit must not be exposed.
+    assert "used" not in data["daily"]
+    assert "limit" not in data["daily"]
+    assert "used" not in data["weekly"]
+    assert "limit" not in data["weekly"]

    mock_get.assert_called_once_with(
        user_id=test_user_id,
-        daily_token_limit=10000,
-        weekly_token_limit=50000,
+        daily_cost_limit=10000,
+        weekly_cost_limit=50000,
        rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
-        tier=SubscriptionTier.FREE,
+        tier=SubscriptionTier.BASIC,
    )

@@ -438,10 +458,10 @@ def test_usage_uses_config_limits(
    assert response.status_code == 200
    mock_get.assert_called_once_with(
        user_id=test_user_id,
-        daily_token_limit=99999,
-        weekly_token_limit=77777,
+        daily_cost_limit=99999,
+        weekly_cost_limit=77777,
        rate_limit_reset_cost=500,
-        tier=SubscriptionTier.FREE,
+        tier=SubscriptionTier.BASIC,
    )

@@ -954,6 +974,618 @@ class TestStripInjectedContext:
        assert result["content"] == "hello"


# ─── message max_length validation ───────────────────────────────────


def test_stream_chat_rejects_too_long_message():
    """A message exceeding max_length=64_000 must be rejected (422)."""
    response = client.post(
        "/sessions/sess-1/stream",
        json={
            "message": "x" * 64_001,
        },
    )
    assert response.status_code == 422


def test_stream_chat_accepts_exactly_max_length_message(
    mocker: pytest_mock.MockFixture,
):
    """A message exactly at max_length=64_000 must be accepted."""
    _mock_stream_internals(mocker)
    mocker.patch(
        "backend.api.features.chat.routes.get_global_rate_limits",
        new_callable=AsyncMock,
        return_value=(0, 0, SubscriptionTier.BASIC),
    )

    response = client.post(
        "/sessions/sess-1/stream",
        json={
            "message": "x" * 64_000,
        },
    )
    assert response.status_code == 200


# ─── list_sessions ────────────────────────────────────────────────────


def _make_session_info(session_id: str = "sess-1", title: str | None = "Test"):
    """Build a minimal ChatSessionInfo-like mock."""
    from backend.copilot.model import ChatSessionInfo, ChatSessionMetadata

    return ChatSessionInfo(
        session_id=session_id,
        user_id=TEST_USER_ID,
        title=title,
        usage=[],
        started_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
        metadata=ChatSessionMetadata(),
    )


def test_list_sessions_returns_sessions(mocker: pytest_mock.MockerFixture) -> None:
    """GET /sessions returns list of sessions with is_processing=False when Redis OK."""
    session = _make_session_info("sess-abc")
    mocker.patch(
        "backend.api.features.chat.routes.get_user_sessions",
        new_callable=AsyncMock,
        return_value=([session], 1),
    )
    # Redis pipeline returns "done" (not "running") for this session
    mock_redis = MagicMock()
    mock_pipe = MagicMock()
    mock_pipe.hget = MagicMock(return_value=None)
    mock_pipe.execute = AsyncMock(return_value=["done"])
    mock_redis.pipeline = MagicMock(return_value=mock_pipe)
    mocker.patch(
        "backend.api.features.chat.routes.get_redis_async",
        new_callable=AsyncMock,
        return_value=mock_redis,
    )

    response = client.get("/sessions")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 1
    assert len(data["sessions"]) == 1
    assert data["sessions"][0]["id"] == "sess-abc"
    assert data["sessions"][0]["is_processing"] is False


def test_list_sessions_marks_running_as_processing(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Sessions with Redis status='running' should have is_processing=True."""
    session = _make_session_info("sess-xyz")
    mocker.patch(
        "backend.api.features.chat.routes.get_user_sessions",
        new_callable=AsyncMock,
        return_value=([session], 1),
    )
    mock_redis = MagicMock()
    mock_pipe = MagicMock()
    mock_pipe.hget = MagicMock(return_value=None)
    mock_pipe.execute = AsyncMock(return_value=["running"])
    mock_redis.pipeline = MagicMock(return_value=mock_pipe)
    mocker.patch(
        "backend.api.features.chat.routes.get_redis_async",
        new_callable=AsyncMock,
        return_value=mock_redis,
    )

    response = client.get("/sessions")

    assert response.status_code == 200
    assert response.json()["sessions"][0]["is_processing"] is True


def test_list_sessions_redis_failure_defaults_to_not_processing(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Redis failures must be swallowed and sessions default to is_processing=False."""
    session = _make_session_info("sess-fallback")
    mocker.patch(
        "backend.api.features.chat.routes.get_user_sessions",
        new_callable=AsyncMock,
        return_value=([session], 1),
    )
    mocker.patch(
        "backend.api.features.chat.routes.get_redis_async",
        side_effect=Exception("Redis down"),
    )

    response = client.get("/sessions")

    assert response.status_code == 200
    assert response.json()["sessions"][0]["is_processing"] is False


def test_list_sessions_empty(mocker: pytest_mock.MockerFixture) -> None:
    """GET /sessions with no sessions returns empty list without hitting Redis."""
    mocker.patch(
        "backend.api.features.chat.routes.get_user_sessions",
        new_callable=AsyncMock,
        return_value=([], 0),
    )

    response = client.get("/sessions")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 0
    assert data["sessions"] == []


# ─── delete_session ───────────────────────────────────────────────────


def test_delete_session_success(mocker: pytest_mock.MockerFixture) -> None:
    """DELETE /sessions/{id} returns 204 when deleted successfully."""
    mocker.patch(
        "backend.api.features.chat.routes.delete_chat_session",
        new_callable=AsyncMock,
        return_value=True,
    )
    # Patch use_e2b_sandbox env-var to disable E2B so the route skips sandbox cleanup.
    # Patching the Pydantic property directly doesn't work (Pydantic v2 intercepts
    # attribute setting on BaseSettings instances and raises AttributeError).
    mocker.patch.dict("os.environ", {"USE_E2B_SANDBOX": "false"})

    response = client.delete("/sessions/sess-1")

    assert response.status_code == 204


def test_delete_session_not_found(mocker: pytest_mock.MockerFixture) -> None:
    """DELETE /sessions/{id} returns 404 when session not found or not owned."""
    mocker.patch(
        "backend.api.features.chat.routes.delete_chat_session",
        new_callable=AsyncMock,
        return_value=False,
    )

    response = client.delete("/sessions/sess-missing")

    assert response.status_code == 404


# ─── cancel_session_task ──────────────────────────────────────────────


def _mock_validate_session(
    mocker: pytest_mock.MockerFixture, *, session_id: str = "sess-1"
):
    """Mock _validate_and_get_session to return a dummy session."""
    from backend.copilot.model import ChatSession

    dummy = ChatSession.new(TEST_USER_ID, dry_run=False)
    mocker.patch(
        "backend.api.features.chat.routes._validate_and_get_session",
        new_callable=AsyncMock,
        return_value=dummy,
    )


def test_cancel_session_no_active_task(mocker: pytest_mock.MockerFixture) -> None:
    """Cancel returns cancelled=True with reason when no stream is active."""
    _mock_validate_session(mocker)
    mock_registry = MagicMock()
    mock_registry.get_active_session = AsyncMock(return_value=(None, None))
    mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry)

    response = client.post("/sessions/sess-1/cancel")

    assert response.status_code == 200
    data = response.json()
    assert data["cancelled"] is True
    assert data["reason"] == "no_active_session"


def test_cancel_session_enqueues_cancel_and_confirms(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Cancel enqueues cancel task and returns cancelled=True once stream stops."""
    from backend.copilot.stream_registry import ActiveSession
|
||||
|
||||
_mock_validate_session(mocker)
|
||||
active_session = ActiveSession(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id="turn-1",
|
||||
status="running",
|
||||
)
|
||||
stopped_session = ActiveSession(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id="turn-1",
|
||||
status="completed",
|
||||
)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0"))
|
||||
mock_registry.get_session = AsyncMock(return_value=stopped_session)
|
||||
mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry)
|
||||
mock_enqueue = mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_cancel_task",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/sessions/sess-1/cancel")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["cancelled"] is True
|
||||
mock_enqueue.assert_called_once_with("sess-1")
|
||||
|
||||
|
||||
# ─── session_assign_user ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_session_assign_user(mocker: pytest_mock.MockerFixture) -> None:
|
||||
"""PATCH /sessions/{id}/assign-user calls assign_user_to_session and returns ok."""
|
||||
mock_assign = mocker.patch(
|
||||
"backend.api.features.chat.routes.chat_service.assign_user_to_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.patch("/sessions/sess-1/assign-user")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_assign.assert_called_once_with("sess-1", TEST_USER_ID)
|
||||
|
||||
|
||||
# ─── get_ttl_config ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_ttl_config(mocker: pytest_mock.MockerFixture) -> None:
|
||||
"""GET /config/ttl returns correct TTL values derived from config."""
|
||||
mocker.patch.object(chat_routes.config, "stream_ttl", 300)
|
||||
|
||||
response = client.get("/config/ttl")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["stream_ttl_seconds"] == 300
|
||||
assert data["stream_ttl_ms"] == 300_000
|
||||
|
||||
|
||||
# ─── reset_copilot_usage ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_reset_internals(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
cost: int = 100,
|
||||
enable_credit: bool = True,
|
||||
daily_limit: int = 10_000,
|
||||
weekly_limit: int = 50_000,
|
||||
tier: "SubscriptionTier" = SubscriptionTier.BASIC,
|
||||
daily_used: int = 10_001,
|
||||
weekly_used: int = 1_000,
|
||||
reset_count: int | None = 0,
|
||||
acquire_lock: bool = True,
|
||||
reset_daily: bool = True,
|
||||
remaining_balance: int = 9_000,
|
||||
):
|
||||
"""Set up all dependencies for reset_copilot_usage tests."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", cost)
|
||||
mocker.patch.object(chat_routes.config, "max_daily_resets", 3)
|
||||
mocker.patch.object(chat_routes.settings.config, "enable_credit", enable_credit)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(daily_limit, weekly_limit, tier),
|
||||
)
|
||||
resets_at = datetime.now(UTC) + timedelta(hours=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_daily_reset_count",
|
||||
new_callable=AsyncMock,
|
||||
return_value=reset_count,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.acquire_reset_lock",
|
||||
new_callable=AsyncMock,
|
||||
return_value=acquire_lock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.release_reset_lock",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.reset_daily_usage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=reset_daily,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.increment_daily_reset_count",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
mock_credit_model = MagicMock()
|
||||
mock_credit_model.spend_credits = AsyncMock(return_value=remaining_balance)
|
||||
mock_credit_model.top_up_credits = AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_user_credit_model",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
return mock_credit_model
|
||||
|
||||
|
||||
def test_reset_usage_returns_400_when_cost_is_zero(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 400 when rate_limit_reset_cost <= 0."""
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 0)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "not available" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_reset_usage_returns_400_when_credits_disabled(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 400 when credit system is disabled."""
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100)
|
||||
mocker.patch.object(chat_routes.settings.config, "enable_credit", False)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "disabled" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_reset_usage_returns_400_when_no_daily_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 400 when daily_limit is 0."""
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100)
|
||||
mocker.patch.object(chat_routes.settings.config, "enable_credit", True)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(0, 50_000, SubscriptionTier.BASIC),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_daily_reset_count",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "nothing to reset" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_reset_usage_returns_503_when_redis_unavailable(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 503 when Redis is unavailable for reset count."""
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100)
|
||||
mocker.patch.object(chat_routes.settings.config, "enable_credit", True)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(10_000, 50_000, SubscriptionTier.BASIC),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_daily_reset_count",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
def test_reset_usage_returns_429_when_max_resets_reached(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 429 when max daily resets exceeded."""
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100)
|
||||
mocker.patch.object(chat_routes.config, "max_daily_resets", 2)
|
||||
mocker.patch.object(chat_routes.settings.config, "enable_credit", True)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(10_000, 50_000, SubscriptionTier.BASIC),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_daily_reset_count",
|
||||
new_callable=AsyncMock,
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 429
|
||||
assert "resets" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_reset_usage_returns_429_when_lock_not_acquired(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 429 when a concurrent reset is in progress."""
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100)
|
||||
mocker.patch.object(chat_routes.config, "max_daily_resets", 3)
|
||||
mocker.patch.object(chat_routes.settings.config, "enable_credit", True)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(10_000, 50_000, SubscriptionTier.BASIC),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_daily_reset_count",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.acquire_reset_lock",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 429
|
||||
assert "in progress" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_reset_usage_returns_400_when_limit_not_reached(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 400 when daily limit has not been reached."""
|
||||
_mock_reset_internals(mocker, daily_used=500, daily_limit=10_000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.release_reset_lock",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "not reached" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_reset_usage_returns_400_when_weekly_also_exhausted(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 400 when weekly limit is also exhausted."""
|
||||
_mock_reset_internals(
|
||||
mocker,
|
||||
daily_used=10_001,
|
||||
daily_limit=10_000,
|
||||
weekly_used=50_001,
|
||||
weekly_limit=50_000,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.release_reset_lock",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "weekly" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_reset_usage_returns_402_when_insufficient_credits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 402 when credits are insufficient."""
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
mock_credit = _mock_reset_internals(mocker)
|
||||
mock_credit.spend_credits = AsyncMock(
|
||||
side_effect=InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=TEST_USER_ID,
|
||||
balance=0.0,
|
||||
amount=100.0,
|
||||
)
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.release_reset_lock",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 402
|
||||
|
||||
|
||||
def test_reset_usage_success(mocker: pytest_mock.MockerFixture) -> None:
|
||||
"""POST /usage/reset returns 200 with updated usage on success."""
|
||||
_mock_reset_internals(mocker, remaining_balance=8_900)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["credits_charged"] == 100
|
||||
assert data["remaining_balance"] == 8_900
|
||||
assert "daily" in data["usage"]
|
||||
assert "weekly" in data["usage"]
|
||||
|
||||
|
||||
def test_reset_usage_refunds_on_redis_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /usage/reset returns 503 and refunds credits when Redis reset fails."""
|
||||
mock_credit = _mock_reset_internals(mocker, reset_daily=False)
|
||||
|
||||
response = client.post("/usage/reset")
|
||||
|
||||
assert response.status_code == 503
|
||||
# Credits should be refunded via top_up_credits
|
||||
mock_credit.top_up_credits.assert_called_once()
|
||||
|
||||
|
||||
# ─── resume_session_stream ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_resume_session_stream_no_active_session(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""GET /sessions/{id}/stream returns 204 when no active session."""
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_active_session = AsyncMock(return_value=(None, None))
|
||||
mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry)
|
||||
|
||||
response = client.get("/sessions/sess-1/stream")
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
def test_resume_session_stream_no_subscriber_queue(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""GET /sessions/{id}/stream returns 204 when subscribe_to_session returns None."""
|
||||
from backend.copilot.stream_registry import ActiveSession
|
||||
|
||||
active_session = ActiveSession(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id="turn-1",
|
||||
status="running",
|
||||
)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0"))
|
||||
mock_registry.subscribe_to_session = AsyncMock(return_value=None)
|
||||
mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry)
|
||||
|
||||
response = client.get("/sessions/sess-1/stream")
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
@@ -1053,3 +1685,119 @@ def test_get_session_returns_backward_paginated(
|
||||
assert data["oldest_sequence"] == 0
|
||||
assert "forward_paginated" not in data
|
||||
assert "newest_sequence" not in data
|
||||
|
||||
|
||||
# ─── POST /sessions with builder_graph_id (get-or-create) ──────────────
|
||||
|
||||
|
||||
def test_create_session_with_builder_graph_id_uses_get_or_create(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""``POST /sessions`` with ``builder_graph_id`` routes through
|
||||
``get_or_create_builder_session`` and returns a session bound to the graph."""
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
async def _fake_get_or_create(user_id: str, graph_id: str) -> ChatSession:
|
||||
return ChatSession.new(
|
||||
user_id,
|
||||
dry_run=False,
|
||||
builder_graph_id=graph_id,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_builder_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_fake_get_or_create,
|
||||
)
|
||||
|
||||
response = client.post("/sessions", json={"builder_graph_id": "graph-1"})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["metadata"]["builder_graph_id"] == "graph-1"
|
||||
assert body["metadata"]["dry_run"] is False
|
||||
|
||||
|
||||
def test_create_session_with_builder_graph_id_returns_404_when_not_owned(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""``get_or_create_builder_session`` raises ``NotFoundError`` when the
|
||||
user doesn't own the graph; the route must map that to HTTP 404."""
|
||||
|
||||
async def _fake_get_or_create(user_id: str, graph_id: str):
|
||||
raise NotFoundError(f"Graph {graph_id} not found")
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_builder_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_fake_get_or_create,
|
||||
)
|
||||
|
||||
response = client.post("/sessions", json={"builder_graph_id": "graph-unauthorized"})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_create_session_without_builder_graph_id_creates_fresh(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""With no ``builder_graph_id`` the endpoint falls through to the
|
||||
default ``create_chat_session`` path — no get-or-create lookup."""
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
gorc = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_builder_session",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
async def _fake_create(user_id: str, *, dry_run: bool) -> ChatSession:
|
||||
return ChatSession.new(user_id, dry_run=dry_run)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.create_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_fake_create,
|
||||
)
|
||||
|
||||
response = client.post("/sessions", json={"dry_run": True})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["metadata"]["dry_run"] is True
|
||||
gorc.assert_not_called()
|
||||
|
||||
|
||||
def test_create_session_rejects_unknown_fields(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Extra request fields are rejected (422) to prevent silent mis-use."""
|
||||
response = client.post("/sessions", json={"unexpected": "x"})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_resolve_session_permissions_blocks_out_of_scope_tools() -> None:
|
||||
"""Builder-bound sessions return a blacklist of the three tools that
|
||||
conflict with the panel's graph-bound scope. Regular sessions return
|
||||
``None`` so default (unrestricted) behaviour is preserved."""
|
||||
from backend.copilot.builder_context import BUILDER_BLOCKED_TOOLS
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
unbound = ChatSession.new("u1", dry_run=False)
|
||||
assert chat_routes.resolve_session_permissions(unbound) is None
|
||||
|
||||
bound = ChatSession.new("u1", dry_run=False, builder_graph_id="g1")
|
||||
perms = chat_routes.resolve_session_permissions(bound)
|
||||
assert perms is not None
|
||||
assert perms.tools_exclude is True # blacklist, not whitelist
|
||||
assert sorted(perms.tools) == sorted(BUILDER_BLOCKED_TOOLS)
|
||||
# Read-side lookups stay available — only write-scope / guide-dup are blocked.
|
||||
assert "find_block" not in perms.tools
|
||||
assert "find_agent" not in perms.tools
|
||||
assert "search_docs" not in perms.tools
|
||||
# The write tools (edit_agent / run_agent) are NOT blacklisted — they
|
||||
# enforce scope per-tool via the builder_graph_id guard.
|
||||
assert "edit_agent" not in perms.tools
|
||||
assert "run_agent" not in perms.tools
|
||||
|
||||
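# --- Illustrative sketch (not part of this commit) ---------------------------
# How a dispatcher might apply the blacklist-style permissions object the test
# above asserts on. The `Permissions` class here is hypothetical; only the
# `tools` / `tools_exclude` field names mirror the asserted object.
from dataclasses import dataclass


@dataclass
class Permissions:
    tools: list[str]
    tools_exclude: bool  # True: `tools` is a blacklist; False: a whitelist


def tool_allowed(name: str, perms: Permissions | None) -> bool:
    if perms is None:  # unbound sessions stay unrestricted
        return True
    listed = name in perms.tools
    return not listed if perms.tools_exclude else listed


assert tool_allowed("tool_a", None)
assert not tool_allowed("tool_b", Permissions(["tool_b"], tools_exclude=True))
assert tool_allowed("tool_a", Permissions(["tool_b"], tools_exclude=True))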
File diff suppressed because it is too large
@@ -14,7 +14,7 @@ from fastapi import (
    Security,
    status,
)
from pydantic import BaseModel, Field, SecretStr, model_validator
from pydantic import BaseModel, Field, model_validator
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY

from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -29,15 +29,14 @@ from backend.data.integrations import (
    wait_for_webhook_event,
)
from backend.data.model import (
    APIKeyCredentials,
    Credentials,
    CredentialsType,
    HostScopedCredentials,
    OAuth2Credentials,
    UserIntegrations,
    is_sdk_default,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import (
@@ -48,7 +47,14 @@ from backend.integrations.creds_manager import (
    IntegrationCredentialsManager,
    create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.managed_credentials import (
    ensure_managed_credential,
    ensure_managed_credentials,
)
from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvider
from backend.integrations.managed_providers.ayrshare import (
    settings_available as ayrshare_settings_available,
)
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -87,14 +93,23 @@ async def login(
    scopes: Annotated[
        str, Query(title="Comma-separated list of authorization scopes")
    ] = "",
    credential_id: Annotated[
        str | None,
        Query(title="ID of existing credential to upgrade scopes for"),
    ] = None,
) -> LoginResponse:
    handler = _get_provider_oauth_handler(request, provider)

    requested_scopes = scopes.split(",") if scopes else []

    if credential_id:
        requested_scopes = await _prepare_scope_upgrade(
            user_id, provider, credential_id, requested_scopes
        )

    # Generate and store a secure random state token along with the scopes
    state_token, code_challenge = await creds_manager.store.store_state_token(
        user_id, provider, requested_scopes
        user_id, provider, requested_scopes, credential_id=credential_id
    )
    login_url = handler.get_login_url(
        requested_scopes, state_token, code_challenge=code_challenge
@@ -216,7 +231,9 @@ async def callback(
    )

    # TODO: Allow specifying `title` to set on `credentials`
    await creds_manager.create(user_id, credentials)
    credentials = await _merge_or_create_credential(
        user_id, provider, credentials, valid_state.credential_id
    )

    logger.debug(
        f"Successfully processed OAuth callback for user {user_id} "
@@ -226,13 +243,38 @@ async def callback(
    return to_meta_response(credentials)


# Bound the first-time sweep so a slow upstream (e.g. Ayrshare) can't hang
# the credential-list endpoint. On timeout we still kick off a fire-and-
# forget sweep so provisioning eventually completes; the user just won't
# see the managed cred until the next refresh.
_MANAGED_PROVISION_TIMEOUT_S = 10.0


async def _ensure_managed_credentials_bounded(user_id: str) -> None:
    try:
        await asyncio.wait_for(
            ensure_managed_credentials(user_id, creds_manager.store),
            timeout=_MANAGED_PROVISION_TIMEOUT_S,
        )
    except asyncio.TimeoutError:
        logger.warning(
            "Managed credential sweep exceeded %.1fs for user=%s; "
            "continuing without it — provisioning will complete in background",
            _MANAGED_PROVISION_TIMEOUT_S,
            user_id,
        )
        asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))


@router.get("/credentials", summary="List Credentials")
async def list_credentials(
    user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
    # Fire-and-forget: provision missing managed credentials in the background.
    # The credential appears on the next page load; listing is never blocked.
    asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
    # Block on provisioning so managed credentials appear on the first load
    # instead of after a refresh, but with a timeout so a slow upstream
    # can't hang the endpoint. `_provisioned_users` short-circuits on
    # repeat calls.
    await _ensure_managed_credentials_bounded(user_id)
    credentials = await creds_manager.store.get_all_creds(user_id)

    return [
@@ -247,7 +289,7 @@ async def list_credentials_by_provider(
    ],
    user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
    asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
    await _ensure_managed_credentials_bounded(user_id)
    credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)

    return [
@@ -281,6 +323,115 @@ async def get_credential(
    return to_meta_response(credential)


class PickerTokenResponse(BaseModel):
    """Short-lived OAuth access token shipped to the browser for rendering a
    provider-hosted picker UI (e.g. Google Drive Picker). Deliberately narrow:
    only the fields the client needs to initialize the picker widget. Issued
    from the user's own stored credential so ownership and scope gating are
    enforced by the credential lookup."""

    access_token: str = Field(
        description="OAuth access token suitable for the picker SDK call."
    )
    access_token_expires_at: int | None = Field(
        default=None,
        description="Unix timestamp at which the access token expires, if known.",
    )


# Allowlist of (provider, scopes) tuples that may mint picker tokens. Only
# Drive-picker-capable scopes qualify so a caller can't use this endpoint to
# extract a GitHub / other-provider OAuth token for unrelated purposes. If a
# future provider integrates a hosted picker that needs a raw access token,
# add its specific picker-relevant scopes here.
_PICKER_TOKEN_ALLOWED_SCOPES: dict[ProviderName, frozenset[str]] = {
    ProviderName.GOOGLE: frozenset(
        [
            "https://www.googleapis.com/auth/drive.file",
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive",
        ]
    ),
}
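# --- Illustrative sketch (not part of this commit) ---------------------------
# The scope gate below passes when the credential holds AT LEAST ONE allowed
# scope, not all of them: `isdisjoint` is True only when the intersection is
# empty. Scope strings are abbreviated here for readability.
allowed = frozenset({"drive.file", "drive.readonly", "drive"})
assert not {"drive.file", "gmail.readonly"}.isdisjoint(allowed)  # one match: token minted
assert {"gmail.readonly", "calendar"}.isdisjoint(allowed)  # no match: HTTP 400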
@router.post(
    "/{provider}/credentials/{cred_id}/picker-token",
    summary="Issue a short-lived access token for a provider-hosted picker",
    operation_id="postV1GetPickerToken",
)
async def get_picker_token(
    provider: Annotated[
        ProviderName, Path(title="The provider that owns the credentials")
    ],
    cred_id: Annotated[
        str, Path(title="The ID of the OAuth2 credentials to mint a token from")
    ],
    user_id: Annotated[str, Security(get_user_id)],
) -> PickerTokenResponse:
    """Return the raw access token for an OAuth2 credential so the frontend
    can initialize a provider-hosted picker (e.g. Google Drive Picker).

    `GET /{provider}/credentials/{cred_id}` deliberately strips secrets (see
    `CredentialsMetaResponse` + `TestGetCredentialReturnsMetaOnly` in
    `router_test.py`). That hardening broke the Drive picker, which needs the
    raw access token to call `google.picker.Builder.setOAuthToken(...)`. This
    endpoint carves a narrow, explicit hole: the caller must own the
    credential, it must be OAuth2, and the endpoint returns only the access
    token + its expiry — nothing else about the credential. SDK-default
    credentials are excluded for the same reason as `get_credential`.
    """
    if is_sdk_default(cred_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )

    credential = await creds_manager.get(user_id, cred_id)
    if not credential:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if not provider_matches(credential.provider, provider):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if not isinstance(credential, OAuth2Credentials):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Picker tokens are only available for OAuth2 credentials",
        )
    if not credential.access_token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Credential has no access token; reconnect the account",
        )

    # Gate on provider+scope: only credentials that actually grant access to
    # a provider-hosted picker flow may mint a token through this endpoint.
    # Prevents using this path to extract bearer tokens for unrelated OAuth
    # integrations (e.g. GitHub) that happen to be stored under the same user.
    allowed_scopes = _PICKER_TOKEN_ALLOWED_SCOPES.get(provider)
    if not allowed_scopes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(f"Picker tokens are not available for provider '{provider.value}'"),
        )
    cred_scopes = set(credential.scopes or [])
    if cred_scopes.isdisjoint(allowed_scopes):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "Credential does not grant any scope eligible for the picker. "
                "Reconnect with the appropriate scope."
            ),
        )

    return PickerTokenResponse(
        access_token=credential.access_token.get_secret_value(),
        access_token_expires_at=credential.access_token_expires_at,
    )


@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
async def create_credentials(
    user_id: Annotated[str, Security(get_user_id)],
@@ -574,6 +725,186 @@ async def _execute_webhook_preset_trigger(
    # Continue processing - webhook should be resilient to individual failures


# -------------------- INCREMENTAL AUTH HELPERS -------------------- #


async def _prepare_scope_upgrade(
    user_id: str,
    provider: ProviderName,
    credential_id: str,
    requested_scopes: list[str],
) -> list[str]:
    """Validate an existing credential for scope upgrade and compute scopes.

    For providers without native incremental auth (e.g. GitHub), returns the
    union of existing + requested scopes. For providers that handle merging
    server-side (e.g. Google with ``include_granted_scopes``), returns the
    requested scopes unchanged.

    Raises HTTPException on validation failure.
    """
    # Platform-owned system credentials must never be upgraded — scope
    # changes here would leak across every user that shares them.
    if is_system_credential(credential_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="System credentials cannot be upgraded",
        )

    existing = await creds_manager.store.get_creds_by_id(user_id, credential_id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Credential to upgrade not found",
        )
    if not isinstance(existing, OAuth2Credentials):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only OAuth2 credentials can be upgraded",
        )
    if not provider_matches(existing.provider, provider.value):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Credential provider does not match the requested provider",
        )
    if existing.is_managed:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Managed credentials cannot be upgraded",
        )

    # Google handles scope merging via include_granted_scopes; others need
    # the union of existing + new scopes in the login URL.
    if provider != ProviderName.GOOGLE:
        requested_scopes = list(set(requested_scopes) | set(existing.scopes))

    return requested_scopes
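# --- Illustrative sketch (not part of this commit) ---------------------------
# The non-Google branch above in set terms: an upgrade never silently drops
# scopes the user already granted. Scope names here are arbitrary examples.
existing_scopes = {"repo"}
requested_scopes = ["read:org"]
merged = list(set(requested_scopes) | existing_scopes)
assert set(merged) == {"repo", "read:org"}
# Google is exempt because include_granted_scopes merges server-side.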
async def _merge_or_create_credential(
    user_id: str,
    provider: ProviderName,
    credentials: OAuth2Credentials,
    credential_id: str | None,
) -> OAuth2Credentials:
    """Either upgrade an existing credential or create a new one.

    When *credential_id* is set (explicit upgrade), merges scopes and updates
    the existing credential. Otherwise, checks for an implicit merge (same
    provider + username) before falling back to creating a new credential.
    """
    if credential_id:
        return await _upgrade_existing_credential(user_id, credential_id, credentials)

    # Implicit merge: check for existing credential with same provider+username.
    # Skip managed/system credentials and require a non-None username on both
    # sides so we never accidentally merge unrelated credentials.
    if credentials.username is None:
        await creds_manager.create(user_id, credentials)
        return credentials

    existing_creds = await creds_manager.store.get_creds_by_provider(user_id, provider)
    matching = next(
        (
            c
            for c in existing_creds
            if isinstance(c, OAuth2Credentials)
            and not c.is_managed
            and not is_system_credential(c.id)
            and c.username is not None
            and c.username == credentials.username
        ),
        None,
    )
    if matching:
        # Only merge into the existing credential when the new token
        # already covers every scope we're about to advertise on it.
        # Without this guard we'd overwrite ``matching.access_token`` with
        # a narrower token while storing a wider ``scopes`` list — the
        # record would claim authorizations the token does not grant, and
        # blocks using the lost scopes would fail with opaque 401/403s
        # until the user hits re-auth. On a narrowing login, keep the
        # two credentials separate instead.
        if set(credentials.scopes).issuperset(set(matching.scopes)):
            return await _upgrade_existing_credential(user_id, matching.id, credentials)

    await creds_manager.create(user_id, credentials)
    return credentials


async def _upgrade_existing_credential(
    user_id: str,
    existing_cred_id: str,
    new_credentials: OAuth2Credentials,
) -> OAuth2Credentials:
    """Merge scopes from *new_credentials* into an existing credential."""
    # Defense-in-depth: re-check system and provider invariants right before
    # the write. The login-time check in `_prepare_scope_upgrade` can go stale
    # by the time the callback runs, and the implicit-merge path bypasses
    # login-time validation entirely, so every write-path must enforce these
    # on its own.
    if is_system_credential(existing_cred_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="System credentials cannot be upgraded",
        )
    existing = await creds_manager.store.get_creds_by_id(user_id, existing_cred_id)
    if not existing or not isinstance(existing, OAuth2Credentials):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Credential to upgrade not found",
        )
    if existing.is_managed:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Managed credentials cannot be upgraded",
        )
    if not provider_matches(existing.provider, new_credentials.provider):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Credential provider does not match the requested provider",
        )

    if (
        existing.username
        and new_credentials.username
        and existing.username != new_credentials.username
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username mismatch: authenticated as a different user",
        )

    # Operate on a copy so the caller's ``new_credentials`` object is not
    # mutated out from under them. Every caller today immediately discards
    # or replaces its reference, but the implicit-merge path in
    # ``_merge_or_create_credential`` reads ``credentials.scopes`` before
    # calling into us — a future reader after the call would otherwise
    # silently see the overwritten values.
    merged = new_credentials.model_copy(deep=True)
    merged.id = existing.id
    merged.title = existing.title
    merged.scopes = list(set(existing.scopes) | set(new_credentials.scopes))
    merged.metadata = {
        **(existing.metadata or {}),
        **(new_credentials.metadata or {}),
    }
    # Preserve the existing refresh_token and username if the incremental
    # response doesn't carry them. Providers like Google only return a
    # refresh_token on first authorization — dropping it here would orphan
    # the credential on the next access-token expiry, forcing the user to
    # re-auth from scratch. Username is similarly sticky: if we've already
    # resolved it for this credential, keep it rather than silently
    # blanking it on an incremental upgrade.
    if not merged.refresh_token and existing.refresh_token:
        merged.refresh_token = existing.refresh_token
        merged.refresh_token_expires_at = existing.refresh_token_expires_at
    if not merged.username and existing.username:
        merged.username = existing.username
    await creds_manager.update(user_id, merged)
    return merged
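# --- Illustrative sketch (not part of this commit) ---------------------------
# Why the merge above keeps the old refresh_token: providers like Google
# return one only on the FIRST authorization, so an incremental grant carries
# a fresh access_token but refresh_token=None. The same keep-if-missing rule,
# shown on plain dicts:
existing = {"access_token": "at-old", "refresh_token": "rt-original"}
incremental = {"access_token": "at-new", "refresh_token": None}
merged = dict(incremental)
if not merged["refresh_token"] and existing["refresh_token"]:
    merged["refresh_token"] = existing["refresh_token"]
assert merged == {"access_token": "at-new", "refresh_token": "rt-original"}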
# --------------------------- UTILITIES ---------------------------- #


@@ -784,12 +1115,21 @@ def _get_provider_oauth_handler(
async def get_ayrshare_sso_url(
    user_id: Annotated[str, Security(get_user_id)],
) -> AyrshareSSOResponse:
    """
    Generate an SSO URL for Ayrshare social media integration.
    """Generate a JWT SSO URL so the user can link their social accounts.

    Returns:
        dict: Contains the SSO URL for Ayrshare integration
    The per-user Ayrshare profile key is provisioned and persisted as a
    standard ``is_managed=True`` credential by
    :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`.
    This endpoint only signs a short-lived JWT pointing at the Ayrshare-
    hosted social-linking page; all profile lifecycle logic lives with the
    managed provider.
    """
    if not ayrshare_settings_available():
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Ayrshare integration is not configured",
        )

    try:
        client = AyrshareClient()
    except MissingConfigError:
@@ -798,66 +1138,63 @@ async def get_ayrshare_sso_url(
            detail="Ayrshare integration is not configured",
        )

    # Ayrshare profile key is stored in the credentials store
    # It is generated when creating a new profile, if there is no profile key,
    # we create a new profile and store the profile key in the credentials store

    user_integrations: UserIntegrations = await get_user_integrations(user_id)
    profile_key = user_integrations.managed_credentials.ayrshare_profile_key

    if not profile_key:
        logger.debug(f"Creating new Ayrshare profile for user {user_id}")
        try:
            profile = await client.create_profile(
                title=f"User {user_id}", messaging_active=True
            )
            profile_key = profile.profileKey
            await creds_manager.store.set_ayrshare_profile_key(user_id, profile_key)
        except Exception as e:
            logger.error(f"Error creating Ayrshare profile for user {user_id}: {e}")
            raise HTTPException(
                status_code=HTTP_502_BAD_GATEWAY,
                detail="Failed to create Ayrshare profile",
            )
    else:
        logger.debug(f"Using existing Ayrshare profile for user {user_id}")

    profile_key_str = (
        profile_key.get_secret_value()
        if isinstance(profile_key, SecretStr)
        else str(profile_key)
    # On-demand provisioning: AyrshareManagedProvider opts out of the
    # credentials sweep (profile quota is per-user subscription-bound). This
    # endpoint is the only trigger that provisions a profile — one Ayrshare
    # profile per user who actually opens the connect flow, not one per
    # every authenticated user.
    provisioned = await ensure_managed_credential(
        user_id, creds_manager.store, AyrshareManagedProvider()
    )
    if not provisioned:
        raise HTTPException(
            status_code=HTTP_502_BAD_GATEWAY,
            detail="Failed to provision Ayrshare profile",
        )

    ayrshare_creds = [
        c
        for c in await creds_manager.store.get_creds_by_provider(user_id, "ayrshare")
        if c.is_managed and isinstance(c, APIKeyCredentials)
    ]
    if not ayrshare_creds:
        logger.error(
            "Ayrshare credential provisioning did not produce a credential "
            "for user %s",
            user_id,
        )
        raise HTTPException(
            status_code=HTTP_502_BAD_GATEWAY,
            detail="Failed to provision Ayrshare profile",
        )
    profile_key_str = ayrshare_creds[0].api_key.get_secret_value()

    private_key = settings.secrets.ayrshare_jwt_key
    # Ayrshare JWT expiry is 2880 minutes (48 hours)
    # Ayrshare JWT max lifetime is 2880 minutes (48 h).
    max_expiry_minutes = 2880
    try:
        logger.debug(f"Generating Ayrshare JWT for user {user_id}")
        jwt_response = await client.generate_jwt(
            private_key=private_key,
            profile_key=profile_key_str,
            # `allowed_social` is the set of networks the Ayrshare-hosted
            # social-linking page will *offer* the user to connect. Blocks
            # exist for more platforms than are listed here; the list is
            # deliberately narrower so the rollout can verify each network
            # end-to-end before widening the user-visible surface. Keep
            # in sync with tested platforms — extend as each is verified
            # against the block + Ayrshare's network-specific quirks.
            allowed_social=[
                # NOTE: We are enabling platforms one at a time
                # to speed up the development process
                # SocialPlatform.FACEBOOK,
                SocialPlatform.TWITTER,
                SocialPlatform.LINKEDIN,
                SocialPlatform.INSTAGRAM,
                SocialPlatform.YOUTUBE,
                # SocialPlatform.REDDIT,
                # SocialPlatform.TELEGRAM,
                # SocialPlatform.GOOGLE_MY_BUSINESS,
                # SocialPlatform.PINTEREST,
                SocialPlatform.TIKTOK,
                # SocialPlatform.BLUESKY,
                # SocialPlatform.SNAPCHAT,
                # SocialPlatform.THREADS,
            ],
            expires_in=max_expiry_minutes,
            verify=True,
        )
    except Exception as e:
        logger.error(f"Error generating Ayrshare JWT for user {user_id}: {e}")
    except Exception as exc:
        logger.error("Error generating Ayrshare JWT for user %s: %s", user_id, exc)
        raise HTTPException(
            status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT"
        )

@@ -393,7 +393,7 @@ class TestEnsureManagedCredentials:
        _PROVIDERS.update(saved)
        _provisioned_users.pop("user-1", None)

        provider.provision.assert_awaited_once_with("user-1")
        provider.provision.assert_awaited_once_with("user-1", store)
        store.add_managed_credential.assert_awaited_once_with("user-1", cred)

    @pytest.mark.asyncio
@@ -568,3 +568,181 @@ class TestCleanupManagedCredentials:
        _PROVIDERS.update(saved)

        # No exception raised — cleanup failure is swallowed.


class TestGetPickerToken:
    """POST /{provider}/credentials/{cred_id}/picker-token must:
    1. Return the access token for OAuth2 creds the caller owns.
    2. 404 for non-owned, non-existent, or wrong-provider creds.
    3. 400 for non-OAuth2 creds (API key, host-scoped, user/password).
    4. 404 for SDK default creds (same hardening as get_credential).
    5. Preserve the `TestGetCredentialReturnsMetaOnly` contract — the
       existing meta-only endpoint must still strip secrets even after
       this picker-token endpoint exists."""

    def test_oauth2_owner_gets_access_token(self):
        # Use a Google cred with a drive.file scope — only picker-eligible
        # (provider, scope) pairs can mint a token. GitHub-style creds are
        # explicitly rejected; see `test_non_picker_provider_rejected_as_400`.
        cred = _make_oauth2_cred(
            cred_id="cred-gdrive",
            provider="google",
        )
        cred.scopes = ["https://www.googleapis.com/auth/drive.file"]
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/google/credentials/cred-gdrive/picker-token")

        assert resp.status_code == 200
        data = resp.json()
        # The whole point of this endpoint: the access token IS returned here.
        assert data["access_token"] == "ghp_secret_token"
        # Only the two declared fields come back — nothing else leaks.
        assert set(data.keys()) <= {"access_token", "access_token_expires_at"}

    def test_non_picker_provider_rejected_as_400(self):
        """Provider allowlist: even with a valid OAuth2 credential, a
        non-picker provider (GitHub, etc.) cannot mint a picker token.
        Stops this endpoint from being used as a generic bearer-token
        extraction path for any stored OAuth cred under the same user."""
        cred = _make_oauth2_cred(provider="github")
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/github/credentials/cred-456/picker-token")

        assert resp.status_code == 400
        assert "not available for provider" in resp.json()["detail"]
        assert "ghp_secret_token" not in str(resp.json())

    def test_google_oauth_without_drive_scope_rejected(self):
        """Scope allowlist: a Google OAuth2 cred that only carries non-picker
        scopes (e.g. gmail.readonly, calendar) cannot mint a picker token.
        Forces the frontend to reconnect with a Drive scope before the
        picker is available."""
        cred = _make_oauth2_cred(provider="google")
        cred.scopes = [
            "https://www.googleapis.com/auth/gmail.readonly",
            "https://www.googleapis.com/auth/calendar",
        ]
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/google/credentials/cred-456/picker-token")

        assert resp.status_code == 400
        assert "picker" in resp.json()["detail"].lower()

    def test_api_key_credential_rejected_as_400(self):
        cred = _make_api_key_cred()
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/openai/credentials/cred-123/picker-token")

        assert resp.status_code == 400
        # API keys must not silently fall through to a 200 response of some
        # other shape — the client should see a clear shape rejection.
        body = str(resp.json())
        assert "sk-secret-key-value" not in body

    def test_user_password_credential_rejected_as_400(self):
        cred = _make_user_password_cred()
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/openai/credentials/cred-789/picker-token")

        assert resp.status_code == 400
        body = str(resp.json())
        assert "s3cret-pass" not in body
        assert "admin" not in body

    def test_host_scoped_credential_rejected_as_400(self):
        cred = _make_host_scoped_cred()
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/openai/credentials/cred-host/picker-token")

        assert resp.status_code == 400
        assert "top-secret" not in str(resp.json())

    def test_missing_credential_returns_404(self):
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=None)
            resp = client.post("/github/credentials/nonexistent/picker-token")

        assert resp.status_code == 404
        assert resp.json()["detail"] == "Credentials not found"

    def test_wrong_provider_returns_404(self):
        """Symmetric with get_credential: provider mismatch is a generic
        404, not a 400, so we don't leak existence of a credential the
        caller doesn't own on that provider."""
        cred = _make_oauth2_cred(provider="github")
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/google/credentials/cred-456/picker-token")

        assert resp.status_code == 404
        assert resp.json()["detail"] == "Credentials not found"

    def test_sdk_default_returns_404(self):
        """SDK defaults are invisible to the user-facing API — picker-token
        must not mint a token for them either."""
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock()
            resp = client.post("/openai/credentials/openai-default/picker-token")

        assert resp.status_code == 404
        mock_mgr.get.assert_not_called()

    def test_oauth2_without_access_token_returns_400(self):
        """A stored OAuth2 cred whose access_token is missing can't satisfy
        a picker init. Surface a clear reconnect instruction rather than
        returning an empty string."""
        cred = _make_oauth2_cred()
        # Simulate a cred that lost its access token
        object.__setattr__(cred, "access_token", None)

        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.post("/github/credentials/cred-456/picker-token")

        assert resp.status_code == 400
        assert "reconnect" in resp.json()["detail"].lower()

    def test_meta_only_endpoint_still_strips_access_token(self):
        """Regression guard for the coexistence contract: the new
        picker-token endpoint must NOT accidentally leak the token through
        the meta-only GET endpoint. TestGetCredentialReturnsMetaOnly
        covers this more broadly; this is a fast sanity check co-located
        with the new endpoint's tests."""
        cred = _make_oauth2_cred()
        with patch(
            "backend.api.features.integrations.router.creds_manager"
        ) as mock_mgr:
            mock_mgr.get = AsyncMock(return_value=cred)
            resp = client.get("/github/credentials/cred-456")

        assert resp.status_code == 200
        body = resp.json()
        assert "access_token" not in body
        assert "refresh_token" not in body
        assert "ghp_secret_token" not in str(body)

@@ -743,6 +743,7 @@ async def update_library_agent_version_and_settings(
        graph=agent_graph,
        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
        builder_chat_session_id=library.settings.builder_chat_session_id,
    )
    if updated_settings != library.settings:
        library = await update_library_agent(
@@ -1803,7 +1804,7 @@ async def create_preset_from_graph_execution(
        raise NotFoundError(
            f"Graph #{graph_execution.graph_id} not found or accessible"
        )
    elif len(graph.aggregate_credentials_inputs()) > 0:
    elif len(graph.regular_credentials_inputs) > 0:
        raise ValueError(
            f"Graph execution #{graph_exec_id} can't be turned into a preset "
            "because it was run before this feature existed "

@@ -0,0 +1 @@
"""Platform bot linking — user-facing REST routes."""
@@ -0,0 +1,158 @@
"""User-facing platform_linking REST routes (JWT auth)."""

import logging
from typing import Annotated

from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Path, Security

from backend.data.db_accessors import platform_linking_db
from backend.platform_linking.models import (
    ConfirmLinkResponse,
    ConfirmUserLinkResponse,
    DeleteLinkResponse,
    LinkTokenInfoResponse,
    PlatformLinkInfo,
    PlatformUserLinkInfo,
)
from backend.util.exceptions import (
    LinkAlreadyExistsError,
    LinkFlowMismatchError,
    LinkTokenExpiredError,
    NotAuthorizedError,
    NotFoundError,
)

logger = logging.getLogger(__name__)

router = APIRouter()

TokenPath = Annotated[
    str,
    Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
]

def _translate(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, NotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
if isinstance(exc, NotAuthorizedError):
|
||||
return HTTPException(status_code=403, detail=str(exc))
|
||||
if isinstance(exc, LinkAlreadyExistsError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
if isinstance(exc, LinkTokenExpiredError):
|
||||
return HTTPException(status_code=410, detail=str(exc))
|
||||
if isinstance(exc, LinkFlowMismatchError):
|
||||
return HTTPException(status_code=400, detail=str(exc))
|
||||
return HTTPException(status_code=500, detail="Internal error.")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tokens/{token}/info",
|
||||
response_model=LinkTokenInfoResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Get display info for a link token",
|
||||
)
|
||||
async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
|
||||
try:
|
||||
return await platform_linking_db().get_link_token_info(token)
|
||||
except (NotFoundError, LinkTokenExpiredError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens/{token}/confirm",
|
||||
response_model=ConfirmLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a SERVER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_server_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user-tokens/{token}/confirm",
|
||||
response_model=ConfirmUserLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a USER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_user_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmUserLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_user_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/links",
|
||||
response_model=list[PlatformLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all platform servers linked to the authenticated user",
|
||||
)
|
||||
async def list_my_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformLinkInfo]:
|
||||
return await platform_linking_db().list_server_links(user_id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user-links",
|
||||
response_model=list[PlatformUserLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all DM links for the authenticated user",
|
||||
)
|
||||
async def list_my_user_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformUserLinkInfo]:
|
||||
return await platform_linking_db().list_user_links(user_id)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a platform server",
|
||||
)
|
||||
async def delete_link(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_server_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/user-links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a DM / user link",
|
||||
)
|
||||
async def delete_user_link_route(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_user_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
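The adversarial tests further down exercise this router end to end. As a usage sketch, the wiring below mirrors their `client` fixture (the `/api/platform-linking` prefix and the auth overrides are taken from those fixtures; in the tests `platform_linking_db` is additionally patched so domain errors flow through `_translate`):

```python
# Illustrative wiring only — mirrors the test fixtures below, not production setup.
import fastapi
from autogpt_libs.auth import get_user_id, requires_user
from fastapi.testclient import TestClient

import backend.api.features.platform_linking.routes as routes_mod

app = fastapi.FastAPI()
app.dependency_overrides[requires_user] = lambda: None  # bypass JWT auth in the sketch
app.dependency_overrides[get_user_id] = lambda: "caller-user"
app.include_router(routes_mod.router, prefix="/api/platform-linking")

client = TestClient(app)
# A token failing the ^[A-Za-z0-9_-]+$ path pattern is rejected by FastAPI
# with 422 before any handler (or the DB accessor) ever runs.
assert client.get("/api/platform-linking/tokens/bad%24token/info").status_code == 422
```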
@@ -0,0 +1,264 @@
"""Route tests: domain exceptions → HTTPException status codes."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import HTTPException

from backend.util.exceptions import (
    LinkAlreadyExistsError,
    LinkFlowMismatchError,
    LinkTokenExpiredError,
    NotAuthorizedError,
    NotFoundError,
)


def _db_mock(**method_configs):
    """Return a mock of the accessor's return value with the given AsyncMocks."""
    db = MagicMock()
    for name, mock in method_configs.items():
        setattr(db, name, mock)
    return db


class TestTokenInfoRouteTranslation:
    @pytest.mark.asyncio
    async def test_not_found_maps_to_404(self):
        from backend.api.features.platform_linking.routes import (
            get_link_token_info_route,
        )

        db = _db_mock(
            get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as exc:
                await get_link_token_info_route(token="abc")
            assert exc.value.status_code == 404

    @pytest.mark.asyncio
    async def test_expired_maps_to_410(self):
        from backend.api.features.platform_linking.routes import (
            get_link_token_info_route,
        )

        db = _db_mock(
            get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as exc:
                await get_link_token_info_route(token="abc")
            assert exc.value.status_code == 410


class TestConfirmLinkRouteTranslation:
    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "exc,expected_status",
        [
            (NotFoundError("missing"), 404),
            (LinkFlowMismatchError("wrong flow"), 400),
            (LinkTokenExpiredError("expired"), 410),
            (LinkAlreadyExistsError("already"), 409),
        ],
    )
    async def test_translation(self, exc: Exception, expected_status: int):
        from backend.api.features.platform_linking.routes import confirm_link_token

        db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as ctx:
                await confirm_link_token(token="abc", user_id="u1")
            assert ctx.value.status_code == expected_status


class TestConfirmUserLinkRouteTranslation:
    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "exc,expected_status",
        [
            (NotFoundError("missing"), 404),
            (LinkFlowMismatchError("wrong flow"), 400),
            (LinkTokenExpiredError("expired"), 410),
            (LinkAlreadyExistsError("already"), 409),
        ],
    )
    async def test_translation(self, exc: Exception, expected_status: int):
        from backend.api.features.platform_linking.routes import confirm_user_link_token

        db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as ctx:
                await confirm_user_link_token(token="abc", user_id="u1")
            assert ctx.value.status_code == expected_status


class TestDeleteLinkRouteTranslation:
    @pytest.mark.asyncio
    async def test_not_found_maps_to_404(self):
        from backend.api.features.platform_linking.routes import delete_link

        db = _db_mock(
            delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as exc:
                await delete_link(link_id="x", user_id="u1")
            assert exc.value.status_code == 404

    @pytest.mark.asyncio
    async def test_not_owned_maps_to_403(self):
        from backend.api.features.platform_linking.routes import delete_link

        db = _db_mock(
            delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as exc:
                await delete_link(link_id="x", user_id="u1")
            assert exc.value.status_code == 403


class TestDeleteUserLinkRouteTranslation:
    @pytest.mark.asyncio
    async def test_not_found_maps_to_404(self):
        from backend.api.features.platform_linking.routes import delete_user_link_route

        db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as exc:
                await delete_user_link_route(link_id="x", user_id="u1")
            assert exc.value.status_code == 404

    @pytest.mark.asyncio
    async def test_not_owned_maps_to_403(self):
        from backend.api.features.platform_linking.routes import delete_user_link_route

        db = _db_mock(
            delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            with pytest.raises(HTTPException) as exc:
                await delete_user_link_route(link_id="x", user_id="u1")
            assert exc.value.status_code == 403


# ── Adversarial: malformed token path params ──────────────────────────


class TestAdversarialTokenPath:
    # TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.

    @pytest.fixture
    def client(self):
        import fastapi
        from autogpt_libs.auth import get_user_id, requires_user
        from fastapi.testclient import TestClient

        import backend.api.features.platform_linking.routes as routes_mod

        app = fastapi.FastAPI()
        app.dependency_overrides[requires_user] = lambda: None
        app.dependency_overrides[get_user_id] = lambda: "caller-user"
        app.include_router(routes_mod.router, prefix="/api/platform-linking")
        return TestClient(app)

    def test_rejects_token_with_special_chars(self, client):
        response = client.get("/api/platform-linking/tokens/bad%24token/info")
        assert response.status_code == 422

    def test_rejects_token_with_path_traversal(self, client):
        for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
            response = client.get(f"/api/platform-linking/tokens/{probe}/info")
            assert response.status_code in (
                404,
                422,
            ), f"path-traversal probe {probe!r} returned {response.status_code}"

    def test_rejects_token_too_long(self, client):
        long_token = "a" * 65
        response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
        assert response.status_code == 422

    def test_accepts_token_at_max_length(self, client):
        token = "a" * 64
        db = _db_mock(
            get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            response = client.get(f"/api/platform-linking/tokens/{token}/info")
            assert response.status_code == 404

    def test_accepts_urlsafe_b64_token_shape(self, client):
        db = _db_mock(
            get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
            assert response.status_code == 404

    def test_confirm_rejects_malformed_token(self, client):
        response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
        assert response.status_code == 422


class TestAdversarialDeleteLinkId:
    """DELETE link_id has no regex — ensure weird values are handled via
    NotFoundError (no crash, no cross-user leak)."""

    @pytest.fixture
    def client(self):
        import fastapi
        from autogpt_libs.auth import get_user_id, requires_user
        from fastapi.testclient import TestClient

        import backend.api.features.platform_linking.routes as routes_mod

        app = fastapi.FastAPI()
        app.dependency_overrides[requires_user] = lambda: None
        app.dependency_overrides[get_user_id] = lambda: "caller-user"
        app.include_router(routes_mod.router, prefix="/api/platform-linking")
        return TestClient(app)

    def test_weird_link_id_returns_404(self, client):
        db = _db_mock(
            delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
        )
        with patch(
            "backend.api.features.platform_linking.routes.platform_linking_db",
            return_value=db,
        ):
            for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
                response = client.delete(f"/api/platform-linking/links/{link_id}")
                assert response.status_code in (404, 405)
@@ -189,7 +189,7 @@ async def test_create_store_submission(mocker):
        notifyOnAgentApproved=True,
        notifyOnAgentRejected=True,
        timezone="Europe/Delft",
-       subscriptionTier=prisma.enums.SubscriptionTier.FREE,  # type: ignore[reportCallIssue,reportAttributeAccessIssue]
+       subscriptionTier=prisma.enums.SubscriptionTier.BASIC,  # type: ignore[reportCallIssue,reportAttributeAccessIssue]
    )
    mock_agent = prisma.models.AgentGraph(
        id="agent-id",
@@ -47,6 +47,62 @@ def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None:
    )


+@pytest.fixture(autouse=True)
+def _stub_pending_subscription_change(mocker: pytest_mock.MockFixture) -> None:
+    """Default pending-change lookup to None so tests don't hit Stripe/DB.
+
+    Individual tests can override via their own mocker.patch call.
+    """
+    mocker.patch(
+        "backend.api.features.v1.get_pending_subscription_change",
+        new_callable=AsyncMock,
+        return_value=None,
+    )
+
+
+_DEFAULT_TIER_PRICES: dict[SubscriptionTier, str | None] = {
+    SubscriptionTier.BASIC: None,  # Legacy: stripe-price-id-basic unset by default.
+    SubscriptionTier.PRO: "price_pro",
+    SubscriptionTier.MAX: "price_max",
+    SubscriptionTier.BUSINESS: None,  # Reserved: Business card hidden by default.
+}
+
+
+@pytest.fixture(autouse=True)
+def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None:
+    """Stub Stripe price + proration + tier-multiplier lookups used by
+    get_subscription_status.
+
+    The POST /credits/subscription handler now returns the full subscription
+    status payload from every branch (same-tier, BASIC downgrade, paid→paid
+    modify, checkout creation), so every POST test implicitly hits these
+    helpers. Individual tests can override via their own mocker.patch call.
+    """
+
+    async def default_price_id(tier: SubscriptionTier) -> str | None:
+        return _DEFAULT_TIER_PRICES.get(tier)
+
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=default_price_id,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_proration_credit_cents",
+        new_callable=AsyncMock,
+        return_value=0,
+    )
+    # Default tier-multiplier resolver to the backend defaults so the endpoint
+    # never reaches LaunchDarkly during tests. Individual tests override for
+    # LD-override scenarios.
+    from backend.copilot.rate_limit import _DEFAULT_TIER_MULTIPLIERS
+
+    mocker.patch(
+        "backend.api.features.v1.get_tier_multipliers",
+        new_callable=AsyncMock,
+        return_value=dict(_DEFAULT_TIER_MULTIPLIERS),
+    )
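As both fixture docstrings note, individual tests can shadow these autouse stubs. A minimal sketch of that override pattern, assuming the same `client`/`mocker` fixtures as the surrounding tests (it mirrors the pending-tier test added later in this diff):

```python
import datetime as dt
from unittest.mock import AsyncMock

from prisma.enums import SubscriptionTier


def test_pending_change_override_sketch(client, mocker):
    # A later mocker.patch on the same target wins over the autouse stub,
    # so a single test can inject its own pending change without Stripe.
    effective_at = dt.datetime(2030, 1, 1, tzinfo=dt.timezone.utc)
    mocker.patch(
        "backend.api.features.v1.get_pending_subscription_change",
        new_callable=AsyncMock,
        return_value=(SubscriptionTier.PRO, effective_at),
    )
    data = client.get("/credits/subscription").json()
    assert data["pending_tier"] == "PRO"
```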
@pytest.mark.parametrize(
    "url,expected",
    [
@@ -88,15 +144,28 @@ def test_get_subscription_status_pro(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
-    """GET /credits/subscription returns PRO tier with Stripe price for a PRO user."""
+    """GET /credits/subscription returns PRO tier with Stripe prices for all priced tiers."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

+    prices = {
+        SubscriptionTier.BASIC: "price_basic",
+        SubscriptionTier.PRO: "price_pro",
+        SubscriptionTier.MAX: "price_max",
+        SubscriptionTier.BUSINESS: "price_business",
+    }
+    amounts = {
+        "price_basic": 0,
+        "price_pro": 1999,
+        "price_max": 4999,
+        "price_business": 14999,
+    }

    async def mock_price_id(tier: SubscriptionTier) -> str | None:
-        return "price_pro" if tier == SubscriptionTier.PRO else None
+        return prices.get(tier)

    async def mock_stripe_price_amount(price_id: str) -> int:
-        return 1999 if price_id == "price_pro" else 0
+        return amounts.get(price_id, 0)

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
@@ -124,16 +193,63 @@ def test_get_subscription_status_pro(
    assert data["tier"] == "PRO"
    assert data["monthly_cost"] == 1999
    assert data["tier_costs"]["PRO"] == 1999
-    assert data["tier_costs"]["BUSINESS"] == 0
-    assert data["tier_costs"]["FREE"] == 0
+    assert data["tier_costs"]["MAX"] == 4999
+    assert data["tier_costs"]["BUSINESS"] == 14999
+    assert data["tier_costs"]["BASIC"] == 0
+    assert "ENTERPRISE" not in data["tier_costs"]
    assert data["proration_credit_cents"] == 500
+    # tier_multipliers mirrors the same set of tiers that land in tier_costs,
+    # so the frontend never renders a multiplier badge for a hidden row.
+    assert set(data["tier_multipliers"].keys()) == set(data["tier_costs"].keys())
+    assert data["tier_multipliers"]["BASIC"] == 1.0
+    assert data["tier_multipliers"]["PRO"] == 5.0
+    assert data["tier_multipliers"]["MAX"] == 20.0
+    assert data["tier_multipliers"]["BUSINESS"] == 60.0


-def test_get_subscription_status_defaults_to_free(
+def test_get_subscription_status_tier_multipliers_ld_override(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
-    """GET /credits/subscription when subscription_tier is None defaults to FREE."""
+    """A LaunchDarkly-overridden tier multiplier flows through the response."""
    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.BASIC

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
        return_value=mock_user,
    )

+    # LD says PRO is 7.5× (instead of the 5× default); other tiers unchanged.
+    mocker.patch(
+        "backend.api.features.v1.get_tier_multipliers",
+        new_callable=AsyncMock,
+        return_value={
+            SubscriptionTier.BASIC: 1.0,
+            SubscriptionTier.PRO: 7.5,
+            SubscriptionTier.MAX: 20.0,
+            SubscriptionTier.BUSINESS: 60.0,
+            SubscriptionTier.ENTERPRISE: 60.0,
+        },
+    )

    response = client.get("/credits/subscription")
    assert response.status_code == 200
    data = response.json()
+    # Only tiers that made it into tier_costs get a multiplier (default stub
+    # exposes PRO + MAX via _DEFAULT_TIER_PRICES).
+    assert data["tier_multipliers"]["PRO"] == 7.5
+    assert data["tier_multipliers"]["MAX"] == 20.0
+    # BUSINESS has no price configured → hidden from both maps.
+    assert "BUSINESS" not in data["tier_multipliers"]


+def test_get_subscription_status_defaults_to_basic(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
+    """When all LD price IDs are unset, tier_costs is empty and the caller sees cost=0."""
    mock_user = Mock()
    mock_user.subscription_tier = None
@@ -157,14 +273,9 @@ def test_get_subscription_status_defaults_to_free(

    assert response.status_code == 200
    data = response.json()
-    assert data["tier"] == SubscriptionTier.FREE.value
+    assert data["tier"] == SubscriptionTier.BASIC.value
    assert data["monthly_cost"] == 0
-    assert data["tier_costs"] == {
-        "FREE": 0,
-        "PRO": 0,
-        "BUSINESS": 0,
-        "ENTERPRISE": 0,
-    }
+    assert data["tier_costs"] == {}
    assert data["proration_credit_cents"] == 0

@@ -215,11 +326,11 @@ def test_get_subscription_status_stripe_error_falls_back_to_zero(
    assert data["tier_costs"]["PRO"] == 0


-def test_update_subscription_tier_free_no_payment(
+def test_update_subscription_tier_basic_no_payment(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
-    """POST /credits/subscription to FREE tier when payment disabled skips Stripe."""
+    """POST /credits/subscription to BASIC tier when payment disabled skips Stripe."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO
@@ -240,7 +351,7 @@ def test_update_subscription_tier_free_no_payment(
        new_callable=AsyncMock,
    )

-    response = client.post("/credits/subscription", json={"tier": "FREE"})
+    response = client.post("/credits/subscription", json={"tier": "BASIC"})

    assert response.status_code == 200
    assert response.json()["url"] == ""
@@ -252,7 +363,7 @@ def test_update_subscription_tier_paid_beta_user(
) -> None:
    """POST /credits/subscription for paid tier when payment disabled returns 422."""
    mock_user = Mock()
-    mock_user.subscription_tier = SubscriptionTier.FREE
+    mock_user.subscription_tier = SubscriptionTier.BASIC

    async def mock_feature_disabled(*args, **kwargs):
        return False
@@ -279,7 +390,7 @@ def test_update_subscription_tier_paid_requires_urls(
) -> None:
    """POST /credits/subscription for paid tier without success/cancel URLs returns 422."""
    mock_user = Mock()
-    mock_user.subscription_tier = SubscriptionTier.FREE
+    mock_user.subscription_tier = SubscriptionTier.BASIC

    async def mock_feature_enabled(*args, **kwargs):
        return True
@@ -305,7 +416,7 @@ def test_update_subscription_tier_creates_checkout(
) -> None:
    """POST /credits/subscription creates Stripe Checkout Session for paid upgrade."""
    mock_user = Mock()
-    mock_user.subscription_tier = SubscriptionTier.FREE
+    mock_user.subscription_tier = SubscriptionTier.BASIC

    async def mock_feature_enabled(*args, **kwargs):
        return True
@@ -344,7 +455,7 @@ def test_update_subscription_tier_rejects_open_redirect(
) -> None:
    """POST /credits/subscription rejects success/cancel URLs outside the frontend origin."""
    mock_user = Mock()
-    mock_user.subscription_tier = SubscriptionTier.FREE
+    mock_user.subscription_tier = SubscriptionTier.BASIC

    async def mock_feature_enabled(*args, **kwargs):
        return True
@@ -407,30 +518,77 @@ def test_update_subscription_tier_enterprise_blocked(
    set_tier_mock.assert_not_awaited()


-def test_update_subscription_tier_same_tier_is_noop(
+def test_update_subscription_tier_same_tier_releases_pending_change(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
-    """POST /credits/subscription for the user's current paid tier returns 200 with empty URL.
+    """POST /credits/subscription for the user's current tier releases any pending change.

-    Without this guard a duplicate POST (double-click, browser retry, stale page) would
-    create a second Stripe Checkout Session for the same price, potentially billing the
-    user twice until the webhook reconciliation fires.
+    "Stay on my current tier" — the collapsed replacement for the old
+    /credits/subscription/cancel-pending route. Always calls
+    release_pending_subscription_schedule (idempotent when nothing is pending)
+    and returns the refreshed status with url="". Never creates a Checkout
+    Session — that would double-charge a user who double-clicks their own tier.
    """
    mock_user = Mock()
-    mock_user.subscription_tier = SubscriptionTier.PRO

-    async def mock_feature_enabled(*args, **kwargs):
-        return True
+    mock_user.subscription_tier = SubscriptionTier.BUSINESS

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
        return_value=mock_user,
    )
-    mocker.patch(
+    release_mock = mocker.patch(
        "backend.api.features.v1.release_pending_subscription_schedule",
        new_callable=AsyncMock,
        return_value=True,
    )
    checkout_mock = mocker.patch(
        "backend.api.features.v1.create_subscription_checkout",
        new_callable=AsyncMock,
    )
    feature_mock = mocker.patch(
        "backend.api.features.v1.is_feature_enabled",
-        side_effect=mock_feature_enabled,
+        new_callable=AsyncMock,
+        return_value=True,
    )

    response = client.post(
        "/credits/subscription",
        json={
            "tier": "BUSINESS",
            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
        },
    )

    assert response.status_code == 200
    data = response.json()
    assert data["tier"] == "BUSINESS"
    assert data["url"] == ""
    release_mock.assert_awaited_once_with(TEST_USER_ID)
    checkout_mock.assert_not_awaited()
    # Same-tier branch short-circuits before the payment-flag check.
    feature_mock.assert_not_awaited()


+def test_update_subscription_tier_same_tier_no_pending_change_returns_status(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """Same-tier request when nothing is pending still returns status with url=""."""
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.PRO
+
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    release_mock = mocker.patch(
+        "backend.api.features.v1.release_pending_subscription_schedule",
+        new_callable=AsyncMock,
+        return_value=False,
+    )
+    checkout_mock = mocker.patch(
+        "backend.api.features.v1.create_subscription_checkout",
@@ -447,18 +605,58 @@ def test_update_subscription_tier_same_tier_is_noop(
    )

    assert response.status_code == 200
-    assert response.json()["url"] == ""
+    data = response.json()
+    assert data["tier"] == "PRO"
+    assert data["url"] == ""
+    assert data["pending_tier"] is None
+    release_mock.assert_awaited_once_with(TEST_USER_ID)
    checkout_mock.assert_not_awaited()


-def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db(
+def test_update_subscription_tier_same_tier_stripe_error_returns_502(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
-    """Downgrading to FREE schedules Stripe cancellation at period end.
+    """Same-tier request surfaces a 502 when Stripe release fails.

+    Carries forward the error contract from the removed
+    /credits/subscription/cancel-pending route so clients keep seeing 502 for
+    transient Stripe failures.
+    """
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.BUSINESS
+
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    mocker.patch(
+        "backend.api.features.v1.release_pending_subscription_schedule",
+        side_effect=stripe.StripeError("network"),
+    )
+
+    response = client.post(
+        "/credits/subscription",
+        json={
+            "tier": "BUSINESS",
+            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
+            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
+        },
+    )
+
+    assert response.status_code == 502
+    assert "contact support" in response.json()["detail"].lower()


+def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_not_update_db(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
+    """Downgrading to BASIC schedules Stripe cancellation at period end.

    The DB tier must NOT be updated immediately — the customer.subscription.deleted
-    webhook fires at period end and downgrades to FREE then.
+    webhook fires at period end and downgrades to BASIC then.
    """
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO
@@ -484,18 +682,18 @@ def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_no
        side_effect=mock_feature_enabled,
    )

-    response = client.post("/credits/subscription", json={"tier": "FREE"})
+    response = client.post("/credits/subscription", json={"tier": "BASIC"})

    assert response.status_code == 200
    mock_cancel.assert_awaited_once()
    mock_set_tier.assert_not_awaited()


-def test_update_subscription_tier_free_cancel_failure_returns_502(
+def test_update_subscription_tier_basic_cancel_failure_returns_502(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
-    """Downgrading to FREE returns 502 with a generic error (no Stripe detail leakage)."""
+    """Downgrading to BASIC returns 502 with a generic error (no Stripe detail leakage)."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

@@ -518,7 +716,7 @@ def test_update_subscription_tier_free_cancel_failure_returns_502(
        side_effect=mock_feature_enabled,
    )

-    response = client.post("/credits/subscription", json={"tier": "FREE"})
+    response = client.post("/credits/subscription", json={"tier": "BASIC"})

    assert response.status_code == 502
    detail = response.json()["detail"]
@@ -635,6 +833,16 @@ def test_update_subscription_tier_paid_to_paid_modifies_subscription(
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

+    async def price_id_with_business(tier: SubscriptionTier) -> str | None:
+        return {
+            **_DEFAULT_TIER_PRICES,
+            SubscriptionTier.BUSINESS: "price_business",
+        }.get(tier)
+
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=price_id_with_business,
+    )
    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
@@ -670,6 +878,49 @@ def test_update_subscription_tier_paid_to_paid_modifies_subscription(
    checkout_mock.assert_not_awaited()


+def test_update_subscription_tier_max_checkout(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """POST /credits/subscription from PRO→MAX modifies the existing subscription."""
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.PRO
+
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    mocker.patch(
+        "backend.api.features.v1.is_feature_enabled",
+        new_callable=AsyncMock,
+        return_value=True,
+    )
+    modify_mock = mocker.patch(
+        "backend.api.features.v1.modify_stripe_subscription_for_tier",
+        new_callable=AsyncMock,
+        return_value=True,
+    )
+    checkout_mock = mocker.patch(
+        "backend.api.features.v1.create_subscription_checkout",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/credits/subscription",
+        json={
+            "tier": "MAX",
+            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
+            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
+        },
+    )
+
+    assert response.status_code == 200
+    assert response.json()["url"] == ""
+    modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.MAX)
+    checkout_mock.assert_not_awaited()


def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
@@ -683,6 +934,16 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

+    async def price_id_with_business(tier: SubscriptionTier) -> str | None:
+        return {
+            **_DEFAULT_TIER_PRICES,
+            SubscriptionTier.BUSINESS: "price_business",
+        }.get(tier)
+
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=price_id_with_business,
+    )
    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
@@ -725,6 +986,128 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
    checkout_mock.assert_not_awaited()


+def test_update_subscription_tier_priced_basic_no_sub_falls_through_to_checkout(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """Once stripe-price-id-basic is configured, a BASIC user without an active sub
+    must hit Stripe Checkout rather than being silently set_subscription_tier'd."""
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.BASIC
+
+    async def mock_price_id(tier: SubscriptionTier) -> str | None:
+        return {
+            SubscriptionTier.BASIC: "price_basic",
+            SubscriptionTier.PRO: "price_pro",
+            SubscriptionTier.MAX: "price_max",
+            SubscriptionTier.BUSINESS: "price_business",
+        }.get(tier)
+
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=mock_price_id,
+    )
+    mocker.patch(
+        "backend.api.features.v1.is_feature_enabled",
+        new_callable=AsyncMock,
+        return_value=True,
+    )
+    modify_mock = mocker.patch(
+        "backend.api.features.v1.modify_stripe_subscription_for_tier",
+        new_callable=AsyncMock,
+        return_value=False,
+    )
+    set_tier_mock = mocker.patch(
+        "backend.api.features.v1.set_subscription_tier",
+        new_callable=AsyncMock,
+    )
+    checkout_mock = mocker.patch(
+        "backend.api.features.v1.create_subscription_checkout",
+        new_callable=AsyncMock,
+        return_value="https://checkout.stripe.com/pay/cs_test_priced_basic",
+    )
+
+    response = client.post(
+        "/credits/subscription",
+        json={
+            "tier": "PRO",
+            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
+            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
+        },
+    )
+
+    assert response.status_code == 200
+    assert (
+        response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_priced_basic"
+    )
+    # Priced-BASIC user without an active sub: must NOT silently flip DB tier —
+    # they need to set up payment via Checkout.
+    set_tier_mock.assert_not_awaited()
+    checkout_mock.assert_awaited_once()
+    # modify is still called first; returning False just means "no active sub".
+    modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO)


+def test_update_subscription_tier_target_without_ld_price_returns_422(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """Paid target with no LD-configured Stripe price must fail fast with 422.
+
+    Matches the UI hiding: if `stripe-price-id-pro` resolves to None we can't
+    start a Checkout Session anyway, and we don't want to surface an opaque
+    Stripe error mid-flow. The handler rejects the request before touching
+    Stripe at all.
+    """
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.BASIC
+
+    async def mock_price_id(tier: SubscriptionTier) -> str | None:
+        return None  # Neither BASIC nor PRO have an LD price.
+
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=mock_price_id,
+    )
+    mocker.patch(
+        "backend.api.features.v1.is_feature_enabled",
+        new_callable=AsyncMock,
+        return_value=True,
+    )
+    checkout_mock = mocker.patch(
+        "backend.api.features.v1.create_subscription_checkout",
+        new_callable=AsyncMock,
+    )
+    modify_mock = mocker.patch(
+        "backend.api.features.v1.modify_stripe_subscription_for_tier",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/credits/subscription",
+        json={
+            "tier": "PRO",
+            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
+            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
+        },
+    )
+
+    assert response.status_code == 422
+    assert "not available" in response.json()["detail"].lower()
+    checkout_mock.assert_not_awaited()
+    modify_mock.assert_not_awaited()


def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
@@ -733,6 +1116,16 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

+    async def price_id_with_business(tier: SubscriptionTier) -> str | None:
+        return {
+            **_DEFAULT_TIER_PRICES,
+            SubscriptionTier.BUSINESS: "price_business",
+        }.get(tier)
+
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=price_id_with_business,
+    )
    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
@@ -761,11 +1154,11 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
    assert response.status_code == 502


-def test_update_subscription_tier_free_no_stripe_subscription(
+def test_update_subscription_tier_basic_no_stripe_subscription(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
-    """Downgrading to FREE when no Stripe subscription exists updates DB tier directly.
+    """Downgrading to BASIC when no Stripe subscription exists updates DB tier directly.

    Admin-granted paid tiers have no associated Stripe subscription. When such a
    user requests a self-service downgrade, cancel_stripe_subscription returns False
@@ -796,10 +1189,214 @@ def test_update_subscription_tier_free_no_stripe_subscription(
        new_callable=AsyncMock,
    )

-    response = client.post("/credits/subscription", json={"tier": "FREE"})
+    response = client.post("/credits/subscription", json={"tier": "BASIC"})

    assert response.status_code == 200
    assert response.json()["url"] == ""
    cancel_mock.assert_awaited_once_with(TEST_USER_ID)
    # DB tier must be updated immediately — no webhook will fire for a missing sub
-    set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE)
+    set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BASIC)


+def test_get_subscription_status_includes_pending_tier(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """GET /credits/subscription exposes pending_tier and pending_tier_effective_at."""
+    import datetime as dt
+
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.BUSINESS
+
+    effective_at = dt.datetime(2030, 1, 1, tzinfo=dt.timezone.utc)
+
+    async def mock_price_id(tier: SubscriptionTier) -> str | None:
+        return None
+
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=mock_price_id,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_proration_credit_cents",
+        new_callable=AsyncMock,
+        return_value=0,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_pending_subscription_change",
+        new_callable=AsyncMock,
+        return_value=(SubscriptionTier.PRO, effective_at),
+    )
+
+    response = client.get("/credits/subscription")
+
+    assert response.status_code == 200
+    data = response.json()
+    assert data["pending_tier"] == "PRO"
+    assert data["pending_tier_effective_at"] is not None


+def test_get_subscription_status_no_pending_tier(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """When no pending change exists the response omits pending_tier."""
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.PRO
+
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        new_callable=AsyncMock,
+        return_value=None,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_proration_credit_cents",
+        new_callable=AsyncMock,
+        return_value=0,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_pending_subscription_change",
+        new_callable=AsyncMock,
+        return_value=None,
+    )
+
+    response = client.get("/credits/subscription")
+
+    assert response.status_code == 200
+    data = response.json()
+    assert data["pending_tier"] is None
+    assert data["pending_tier_effective_at"] is None


+def test_update_subscription_tier_downgrade_paid_to_paid_schedules(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """A BUSINESS→PRO downgrade request dispatches to modify_stripe_subscription_for_tier."""
+    mock_user = Mock()
+    mock_user.subscription_tier = SubscriptionTier.BUSINESS
+
+    async def price_id_with_business(tier: SubscriptionTier) -> str | None:
+        return {
+            **_DEFAULT_TIER_PRICES,
+            SubscriptionTier.BUSINESS: "price_business",
+        }.get(tier)
+
+    mocker.patch(
+        "backend.api.features.v1.get_subscription_price_id",
+        side_effect=price_id_with_business,
+    )
+    mocker.patch(
+        "backend.api.features.v1.get_user_by_id",
+        new_callable=AsyncMock,
+        return_value=mock_user,
+    )
+    mocker.patch(
+        "backend.api.features.v1.is_feature_enabled",
+        new_callable=AsyncMock,
+        return_value=True,
+    )
+    modify_mock = mocker.patch(
+        "backend.api.features.v1.modify_stripe_subscription_for_tier",
+        new_callable=AsyncMock,
+        return_value=True,
+    )
+    checkout_mock = mocker.patch(
+        "backend.api.features.v1.create_subscription_checkout",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/credits/subscription",
+        json={
+            "tier": "PRO",
+            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
+            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
+        },
+    )
+
+    assert response.status_code == 200
+    assert response.json()["url"] == ""
+    modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO)
+    checkout_mock.assert_not_awaited()


+def test_stripe_webhook_dispatches_subscription_schedule_released(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """subscription_schedule.released routes to sync_subscription_schedule_from_stripe."""
+    schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"}
+    event = {
+        "type": "subscription_schedule.released",
+        "data": {"object": schedule_obj},
+    }
+    mocker.patch(
+        "backend.api.features.v1.settings.secrets.stripe_webhook_secret",
+        new="whsec_test",
+    )
+    mocker.patch(
+        "backend.api.features.v1.stripe.Webhook.construct_event",
+        return_value=event,
+    )
+    sync_mock = mocker.patch(
+        "backend.api.features.v1.sync_subscription_schedule_from_stripe",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/credits/stripe_webhook",
+        content=b"{}",
+        headers={"stripe-signature": "t=1,v1=abc"},
+    )
+
+    assert response.status_code == 200
+    sync_mock.assert_awaited_once_with(schedule_obj)


+def test_stripe_webhook_ignores_subscription_schedule_updated(
+    client: fastapi.testclient.TestClient,
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """subscription_schedule.updated must NOT dispatch: our own
+    SubscriptionSchedule.create/.modify calls fire this event and would
+    otherwise loop redundant traffic through the sync handler. State
+    transitions we care about surface via .released/.completed, and phase
+    advance to a new price is already covered by customer.subscription.updated.
+    """
+    schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"}
+    event = {
+        "type": "subscription_schedule.updated",
+        "data": {"object": schedule_obj},
+    }
+    mocker.patch(
+        "backend.api.features.v1.settings.secrets.stripe_webhook_secret",
+        new="whsec_test",
+    )
+    mocker.patch(
+        "backend.api.features.v1.stripe.Webhook.construct_event",
+        return_value=event,
+    )
+    sync_mock = mocker.patch(
+        "backend.api.features.v1.sync_subscription_schedule_from_stripe",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/credits/stripe_webhook",
+        content=b"{}",
+        headers={"stripe-signature": "t=1,v1=abc"},
+    )
+
+    assert response.status_code == 200
+    sync_mock.assert_not_awaited()
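The two webhook tests above pin the dispatch contract: `.released` routes to the schedule sync, `.updated` is dropped on the floor. A condensed sketch of the handler shape they imply follows; the real route lives in `backend/api/features/v1.py` and handles many more event types, so treat this as an outline under those assumptions rather than the actual code:

```python
# Outline only — the production handler covers many more event types.
import stripe
from fastapi import APIRouter, Request

from backend.data.credit import sync_subscription_schedule_from_stripe

router = APIRouter()


@router.post("/credits/stripe_webhook")
async def stripe_webhook(request: Request) -> dict[str, str]:
    payload = await request.body()
    sig = request.headers.get("stripe-signature", "")
    # construct_event verifies the payload signature against the webhook secret.
    event = stripe.Webhook.construct_event(payload, sig, "whsec_test")  # secret illustrative
    if event["type"] == "subscription_schedule.released":
        await sync_subscription_schedule_from_stripe(event["data"]["object"])
    # subscription_schedule.updated is deliberately ignored: our own
    # SubscriptionSchedule.create/.modify calls emit it (see the test above).
    return {"status": "ok"}
```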
@@ -26,10 +26,11 @@ from fastapi import (
)
from fastapi.concurrency import run_in_threadpool
from prisma.enums import SubscriptionTier
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict

from backend.api.features.workspace.routes import create_file_download_response
from backend.api.model import (
    CreateAPIKeyRequest,
    CreateAPIKeyResponse,
@@ -43,26 +44,31 @@ from backend.api.model import (
    UploadFileResponse,
)
from backend.blocks import get_block, get_blocks
+from backend.copilot.rate_limit import get_tier_multipliers
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import (
    AutoTopUpConfig,
+    PendingChangeUnknown,
    RefundRequest,
    TransactionHistory,
    UserCredit,
    cancel_stripe_subscription,
    create_subscription_checkout,
    get_auto_top_up,
+    get_pending_subscription_change,
    get_proration_credit_cents,
    get_subscription_price_id,
    get_user_credit_model,
    handle_subscription_payment_failure,
    modify_stripe_subscription_for_tier,
+    release_pending_subscription_schedule,
    set_auto_top_up,
    set_subscription_tier,
    sync_subscription_from_stripe,
+    sync_subscription_schedule_from_stripe,
)
from backend.data.graph import GraphSettings
from backend.data.model import CredentialsMetaInput, UserOnboarding
@@ -92,6 +98,7 @@ from backend.data.user import (
    update_user_notification_preference,
    update_user_timezone,
)
+from backend.data.workspace import get_workspace_file_by_id
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.integrations.webhooks.graph_lifecycle_hooks import (
@@ -693,20 +700,35 @@ async def get_user_auto_top_up(


class SubscriptionTierRequest(BaseModel):
-    tier: Literal["FREE", "PRO", "BUSINESS"]
+    tier: Literal["BASIC", "PRO", "MAX", "BUSINESS"]
    success_url: str = ""
    cancel_url: str = ""


class SubscriptionCheckoutResponse(BaseModel):
    url: str


class SubscriptionStatusResponse(BaseModel):
-    tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
+    tier: Literal["BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
    monthly_cost: int  # amount in cents (Stripe convention)
    tier_costs: dict[str, int]  # tier name -> amount in cents
+    tier_multipliers: dict[str, float] = Field(
+        default_factory=dict,
+        description=(
+            "Tier → rate-limit multiplier. Covers the same tiers listed in"
+            " ``tier_costs`` so the frontend can render rate-limit badges"
+            " relative to the lowest visible tier without knowing backend"
+            " defaults."
+        ),
+    )
    proration_credit_cents: int  # unused portion of current sub to convert on upgrade
+    pending_tier: Optional[Literal["BASIC", "PRO", "MAX", "BUSINESS"]] = None
+    pending_tier_effective_at: Optional[datetime] = None
+    url: str = Field(
+        default="",
+        description=(
+            "Populated only when POST /credits/subscription starts a Stripe Checkout"
+            " Session (BASIC → paid upgrade). Empty string in all other branches —"
+            " the client redirects to this URL when non-empty."
+        ),
+    )


def _validate_checkout_redirect_url(url: str) -> bool:
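Putting the model together, here is a payload a PRO user might see, with values lifted from the `test_get_subscription_status_pro` assertions earlier in this diff (illustrative, not captured from a live environment):

```python
# Values mirror the test assertions above; purely illustrative.
example = SubscriptionStatusResponse(
    tier="PRO",
    monthly_cost=1999,  # cents, Stripe convention
    tier_costs={"BASIC": 0, "PRO": 1999, "MAX": 4999, "BUSINESS": 14999},
    tier_multipliers={"BASIC": 1.0, "PRO": 5.0, "MAX": 20.0, "BUSINESS": 60.0},
    proration_credit_cents=500,
    pending_tier=None,
    pending_tier_effective_at=None,
    url="",  # empty: no Checkout Session was started
)
```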
@@ -782,39 +804,80 @@ async def get_subscription_status(
    user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionStatusResponse:
    user = await get_user_by_id(user_id)
-    tier = user.subscription_tier or SubscriptionTier.FREE
+    tier = user.subscription_tier or SubscriptionTier.BASIC

-    paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS]
+    priceable_tiers = [
+        SubscriptionTier.BASIC,
+        SubscriptionTier.PRO,
+        SubscriptionTier.MAX,
+        SubscriptionTier.BUSINESS,
+    ]
    price_ids = await asyncio.gather(
-        *[get_subscription_price_id(t) for t in paid_tiers]
+        *[get_subscription_price_id(t) for t in priceable_tiers]
    )

-    tier_costs: dict[str, int] = {
-        SubscriptionTier.FREE.value: 0,
-        SubscriptionTier.ENTERPRISE.value: 0,
-    }

    async def _cost(pid: str | None) -> int:
        return (await _get_stripe_price_amount(pid) or 0) if pid else 0

    costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
-    for t, cost in zip(paid_tiers, costs):
-        tier_costs[t.value] = cost

+    tier_costs: dict[str, int] = {}
+    for t, pid, cost in zip(priceable_tiers, price_ids, costs):
+        if pid:
+            tier_costs[t.value] = cost

+    # Expose the effective rate-limit multipliers alongside prices so the
+    # frontend can render "Nx rate limits" relative to the lowest visible
+    # tier without hard-coding backend defaults. Only emit entries for tiers
+    # that land in ``tier_costs`` — rows hidden at the price layer must stay
+    # hidden in the multiplier layer too.
+    multipliers = await get_tier_multipliers()
+    tier_multipliers: dict[str, float] = {
+        t.value: multipliers.get(t, 1.0)
+        for t in priceable_tiers
+        if t.value in tier_costs
+    }

    current_monthly_cost = tier_costs.get(tier.value, 0)
    proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)

-    return SubscriptionStatusResponse(
+    try:
+        pending = await get_pending_subscription_change(user_id)
+    except (stripe.StripeError, PendingChangeUnknown):
+        # Swallow Stripe-side failures (rate limits, transient network) AND
+        # PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
+        # propagate past the cache so the next request retries fresh instead
+        # of serving a stale None for the TTL window. Let real bugs (KeyError,
+        # AttributeError, etc.) propagate so they surface in Sentry.
+        logger.exception(
+            "get_subscription_status: failed to resolve pending change for user %s",
+            user_id,
+        )
+        pending = None

+    response = SubscriptionStatusResponse(
        tier=tier.value,
        monthly_cost=current_monthly_cost,
        tier_costs=tier_costs,
+        tier_multipliers=tier_multipliers,
        proration_credit_cents=proration_credit,
    )
+    if pending is not None:
+        pending_tier_enum, pending_effective_at = pending
+        if pending_tier_enum in (
+            SubscriptionTier.BASIC,
+            SubscriptionTier.PRO,
+            SubscriptionTier.MAX,
+            SubscriptionTier.BUSINESS,
+        ):
+            response.pending_tier = pending_tier_enum.value
+            response.pending_tier_effective_at = pending_effective_at
+    return response


@v1_router.post(
    path="/credits/subscription",
-    summary="Start a Stripe Checkout session to upgrade subscription tier",
+    summary="Update subscription tier or start a Stripe Checkout session",
    operation_id="updateSubscriptionTier",
    tags=["credits"],
    dependencies=[Security(requires_user)],
@@ -822,38 +885,63 @@ async def get_subscription_status(
|
||||
async def update_subscription_tier(
|
||||
request: SubscriptionTierRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionCheckoutResponse:
|
||||
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
|
||||
) -> SubscriptionStatusResponse:
|
||||
# Pydantic validates tier is one of BASIC/PRO/MAX/BUSINESS via Literal type.
|
||||
tier = SubscriptionTier(request.tier)
|
||||
|
||||
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
|
||||
user = await get_user_by_id(user_id)
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
|
||||
if (
|
||||
user.subscription_tier or SubscriptionTier.BASIC
|
||||
) == SubscriptionTier.ENTERPRISE:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="ENTERPRISE subscription changes must be managed by an administrator",
|
||||
)
|
||||
|
||||
# Same-tier request = "stay on my current tier" = cancel any pending
|
||||
# scheduled change (paid→paid downgrade or paid→BASIC cancel). This is the
|
||||
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
|
||||
# route. Safe when no pending change exists: release_pending_subscription_schedule
|
||||
# returns False and we simply return the current status.
|
||||
if (user.subscription_tier or SubscriptionTier.BASIC) == tier:
|
||||
try:
|
||||
await release_pending_subscription_schedule(user_id)
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error releasing pending subscription change for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel the pending subscription change right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
return await get_subscription_status(user_id)
|
||||
|
||||
payment_enabled = await is_feature_enabled(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
|
||||
# keeps their tier for the time they already paid for. The DB tier is NOT
|
||||
# updated here when a subscription exists — the customer.subscription.deleted
|
||||
# webhook fires at period end and downgrades to FREE then.
|
||||
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
|
||||
# tier), cancel_stripe_subscription returns False and we update the DB tier
|
||||
# immediately since no webhook will ever fire.
|
||||
# When payment is disabled entirely, update the DB tier directly.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
current_tier = user.subscription_tier or SubscriptionTier.BASIC
|
||||
target_price_id, current_tier_price_id = await asyncio.gather(
|
||||
get_subscription_price_id(tier),
|
||||
get_subscription_price_id(current_tier),
|
||||
)
|
||||
|
||||
# Legacy cancel: target BASIC + stripe-price-id-basic unset. Schedule Stripe
|
||||
# cancellation at period end; cancel_at_period_end=True lets the webhook flip
|
||||
# the DB tier. No active sub (admin-granted) or payment disabled → DB flip.
|
||||
# Once stripe-price-id-basic is configured, BASIC becomes a real sub and falls
|
||||
# through to the modify/checkout flow below.
|
||||
if tier == SubscriptionTier.BASIC and target_price_id is None:
|
||||
if payment_enabled:
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
# Log full Stripe error server-side but return a generic message
|
||||
# to the client — raw Stripe errors can leak customer/sub IDs and
|
||||
# infrastructure config details.
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
@@ -867,48 +955,37 @@ async def update_subscription_tier(
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
# No active Stripe subscription found — the user was on an
|
||||
# admin-granted tier. Update DB immediately since the
|
||||
# subscription.deleted webhook will never fire.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
return await get_subscription_status(user_id)
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
return await get_subscription_status(user_id)
|
||||
|
||||
# Paid tier changes require payment to be enabled — block self-service upgrades
|
||||
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
|
||||
if not payment_enabled:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier}",
|
||||
detail=f"Subscription not available for tier {tier.value}",
|
||||
)
|
||||
|
||||
# No-op short-circuit: if the user is already on the requested paid tier,
|
||||
# do NOT create a new Checkout Session. Without this guard, a duplicate
|
||||
# request (double-click, retried POST, stale page) creates a second
|
||||
# subscription for the same price; the user would be charged for both
|
||||
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
|
||||
# which only fires after the second charge has cleared.
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
# Target has no LD price — not provisionable (matches the GET hiding).
|
||||
if target_price_id is None:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier.value}",
|
||||
)
|
||||
|
||||
# Paid→paid tier change: if the user already has a Stripe subscription,
|
||||
# modify it in-place with proration instead of creating a new Checkout
|
||||
# Session. This preserves remaining paid time and avoids double-charging.
|
||||
# The customer.subscription.updated webhook fires and updates the DB tier.
|
||||
current_tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
|
||||
# User has an active Stripe subscription (current tier has an LD price):
|
||||
# modify it in-place. modify_stripe_subscription_for_tier returns False when no
|
||||
# active sub exists — that's only a "DB-only flip is OK" signal for admin-granted
|
||||
# paid tiers (PRO/BUSINESS with no Stripe record). Priced-BASIC users without a
|
||||
# sub must still go through Checkout so they set up payment.
|
||||
if current_tier_price_id is not None:
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
# modify_stripe_subscription_for_tier returns False when no active
|
||||
# Stripe subscription exists — i.e. the user has an admin-granted
|
||||
# paid tier with no Stripe record. In that case, update the DB
|
||||
# tier directly (same as the FREE-downgrade path for admin-granted
|
||||
# users) rather than sending them through a new Checkout Session.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
return await get_subscription_status(user_id)
|
||||
if current_tier != SubscriptionTier.BASIC:
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
@@ -923,7 +1000,7 @@ async def update_subscription_tier(
|
||||
),
|
||||
)
|
||||
|
||||
# Paid upgrade from FREE → create Stripe Checkout Session.
|
||||
# No active Stripe subscription → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
@@ -978,7 +1055,9 @@ async def update_subscription_tier(
|
||||
),
|
||||
)
|
||||
|
||||
return SubscriptionCheckoutResponse(url=url)
|
||||
status = await get_subscription_status(user_id)
|
||||
status.url = url
|
||||
return status
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -1043,6 +1122,18 @@ async def stripe_webhook(request: Request):
|
||||
):
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
|
||||
# `subscription_schedule.updated` is deliberately omitted: our own
|
||||
# `SubscriptionSchedule.create` + `.modify` calls in
|
||||
# `_schedule_downgrade_at_period_end` would fire that event right back at us
|
||||
# and loop redundant traffic through this handler. We only care about state
|
||||
# transitions (released / completed); phase advance to the new price is
|
||||
# already covered by `customer.subscription.updated`.
|
||||
if event_type in (
|
||||
"subscription_schedule.released",
|
||||
"subscription_schedule.completed",
|
||||
):
|
||||
await sync_subscription_schedule_from_stripe(data_object)
|
||||
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
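A small sketch of the self-echo avoidance described in the comment above (hypothetical helper, not the platform's handler): only schedule state transitions are synced, so the platform's own `SubscriptionSchedule.modify` calls cannot re-enter the webhook via `subscription_schedule.updated`.

```python
# Sketch only: the event-type filter the hunk above encodes.
SCHEDULE_EVENTS_WE_SYNC = {
    "subscription_schedule.released",
    "subscription_schedule.completed",
}

def should_sync_schedule(event_type: str) -> bool:
    # `subscription_schedule.updated` is excluded on purpose; our own
    # create/modify calls would fire it straight back at this handler.
    return event_type in SCHEDULE_EVENTS_WE_SYNC

assert should_sync_schedule("subscription_schedule.released")
assert not should_sync_schedule("subscription_schedule.updated")
```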
@@ -1640,6 +1731,10 @@ async def enable_execution_sharing(
    # Generate a unique share token
    share_token = str(uuid.uuid4())

    # Remove stale allowlist records before updating the token — prevents a
    # window where old records + new token could coexist.
    await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)

    # Update the execution with share info
    await execution_db.update_graph_execution_share_status(
        execution_id=graph_exec_id,
@@ -1649,6 +1744,14 @@ async def enable_execution_sharing(
        shared_at=datetime.now(timezone.utc),
    )

    # Create allowlist of workspace files referenced in outputs
    await execution_db.create_shared_execution_files(
        execution_id=graph_exec_id,
        share_token=share_token,
        user_id=user_id,
        outputs=execution.outputs,
    )

    # Return the share URL
    frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
    share_url = f"{frontend_url}/share/{share_token}"
@@ -1674,6 +1777,9 @@ async def disable_execution_sharing(
    if not execution:
        raise HTTPException(status_code=404, detail="Execution not found")

    # Remove shared file allowlist records
    await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)

    # Remove share info
    await execution_db.update_graph_execution_share_status(
        execution_id=graph_exec_id,
@@ -1699,6 +1805,43 @@ async def get_shared_execution(
    return execution


@v1_router.get(
    "/public/shared/{share_token}/files/{file_id}/download",
    summary="Download a file from a shared execution",
    operation_id="download_shared_file",
    tags=["graphs"],
)
async def download_shared_file(
    share_token: Annotated[
        str,
        Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
    ],
    file_id: Annotated[
        str,
        Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
    ],
) -> Response:
    """Download a workspace file from a shared execution (no auth required).

    Validates that the file was explicitly exposed when sharing was enabled.
    Returns a uniform 404 for all failure modes to prevent enumeration attacks.
    """
    # Single-query validation against the allowlist
    execution_id = await execution_db.get_shared_execution_file(
        share_token=share_token, file_id=file_id
    )
    if not execution_id:
        raise HTTPException(status_code=404, detail="Not found")

    # Look up the actual file (no workspace scoping needed — the allowlist
    # already validated that this file belongs to the shared execution)
    file = await get_workspace_file_by_id(file_id)
    if not file:
        raise HTTPException(status_code=404, detail="Not found")

    return await create_file_download_response(file, inline=True)
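Illustrative client call against this endpoint (hypothetical local host and tokens; assumes a running dev instance and the `httpx` package). The point is the contract: a malformed UUID is rejected at the path layer with 422, while a token outside the allowlist and a missing file both yield the same 404 body.

```python
# Sketch only: exercising the public shared-file download route.
import httpx  # any HTTP client works

resp = httpx.get(
    "http://localhost:3000/api/public/shared/"
    "550e8400-e29b-41d4-a716-446655440000/files/"
    "6ba7b810-9dad-11d1-80b4-00c04fd430c8/download"
)
# 200 with Content-Disposition: inline on success; uniform 404 otherwise.
print(resp.status_code, resp.headers.get("content-disposition"))
```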
########################################################
##################### Schedules ########################
########################################################

157  autogpt_platform/backend/backend/api/features/v1_share_test.py  (new file)
@@ -0,0 +1,157 @@
"""Tests for the public shared file download endpoint."""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import Response

from backend.api.features.v1 import v1_router
from backend.data.workspace import WorkspaceFile

app = FastAPI()
app.include_router(v1_router, prefix="/api")

VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"


def _make_workspace_file(**overrides) -> WorkspaceFile:
    defaults = {
        "id": VALID_FILE_ID,
        "workspace_id": "ws-001",
        "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
        "updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
        "name": "image.png",
        "path": "/image.png",
        "storage_path": "local://uploads/image.png",
        "mime_type": "image/png",
        "size_bytes": 4,
        "checksum": None,
        "is_deleted": False,
        "deleted_at": None,
        "metadata": {},
    }
    defaults.update(overrides)
    return WorkspaceFile(**defaults)


def _mock_download_response(**kwargs):
    """Return an async stand-in for create_file_download_response whose
    Content-Disposition follows the ``inline`` flag."""

    async def _handler(file, *, inline=False):
        return Response(
            content=b"\x89PNG",
            media_type="image/png",
            headers={
                "Content-Disposition": (
                    'inline; filename="image.png"'
                    if inline
                    else 'attachment; filename="image.png"'
                ),
                "Content-Length": "4",
            },
        )

    return _handler


class TestDownloadSharedFile:
    """Tests for GET /api/public/shared/{token}/files/{id}/download."""

    @pytest.fixture(autouse=True)
    def _client(self):
        self.client = TestClient(app, raise_server_exceptions=False)

    def test_valid_token_and_file_returns_inline_content(self):
        with (
            patch(
                "backend.api.features.v1.execution_db.get_shared_execution_file",
                new_callable=AsyncMock,
                return_value="exec-123",
            ),
            patch(
                "backend.api.features.v1.get_workspace_file_by_id",
                new_callable=AsyncMock,
                return_value=_make_workspace_file(),
            ),
            patch(
                "backend.api.features.v1.create_file_download_response",
                side_effect=_mock_download_response(),
            ),
        ):
            response = self.client.get(
                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
            )

        assert response.status_code == 200
        assert response.content == b"\x89PNG"
        assert "inline" in response.headers["Content-Disposition"]

    def test_invalid_token_format_returns_422(self):
        response = self.client.get(
            f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
        )
        assert response.status_code == 422

    def test_token_not_in_allowlist_returns_404(self):
        with patch(
            "backend.api.features.v1.execution_db.get_shared_execution_file",
            new_callable=AsyncMock,
            return_value=None,
        ):
            response = self.client.get(
                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
            )
        assert response.status_code == 404

    def test_file_missing_from_workspace_returns_404(self):
        with (
            patch(
                "backend.api.features.v1.execution_db.get_shared_execution_file",
                new_callable=AsyncMock,
                return_value="exec-123",
            ),
            patch(
                "backend.api.features.v1.get_workspace_file_by_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
        ):
            response = self.client.get(
                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
            )
        assert response.status_code == 404

    def test_uniform_404_prevents_enumeration(self):
        """Both failure modes produce identical 404 — no information leak."""
        with patch(
            "backend.api.features.v1.execution_db.get_shared_execution_file",
            new_callable=AsyncMock,
            return_value=None,
        ):
            resp_no_allow = self.client.get(
                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
            )

        with (
            patch(
                "backend.api.features.v1.execution_db.get_shared_execution_file",
                new_callable=AsyncMock,
                return_value="exec-123",
            ),
            patch(
                "backend.api.features.v1.get_workspace_file_by_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
        ):
            resp_no_file = self.client.get(
                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
            )

        assert resp_no_allow.status_code == 404
        assert resp_no_file.status_code == 404
        assert resp_no_allow.json() == resp_no_file.json()
@@ -31,7 +31,9 @@ from backend.util.workspace import WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage


def _sanitize_filename_for_header(filename: str) -> str:
def _sanitize_filename_for_header(
    filename: str, disposition: str = "attachment"
) -> str:
    """
    Sanitize filename for Content-Disposition header to prevent header injection.

@@ -46,11 +48,11 @@ def _sanitize_filename_for_header(filename: str) -> str:
    # Check if filename has non-ASCII characters
    try:
        sanitized.encode("ascii")
        return f'attachment; filename="{sanitized}"'
        return f'{disposition}; filename="{sanitized}"'
    except UnicodeEncodeError:
        # Use RFC5987 encoding for UTF-8 filenames
        encoded = quote(sanitized, safe="")
        return f"attachment; filename*=UTF-8''{encoded}"
        return f"{disposition}; filename*=UTF-8''{encoded}"

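The two output shapes are easy to see in a standalone sketch of the same rule (plain stdlib, not an import of the platform helper): ASCII names get a quoted `filename=`, non-ASCII names fall through to the RFC 5987 `filename*=` form.

```python
# Sketch only: quoted filename for ASCII, RFC 5987 extended form otherwise.
from urllib.parse import quote

def content_disposition(filename: str, disposition: str = "attachment") -> str:
    try:
        filename.encode("ascii")
        return f'{disposition}; filename="{filename}"'
    except UnicodeEncodeError:
        # RFC 5987 extended parameter for non-ASCII names
        return f"{disposition}; filename*=UTF-8''{quote(filename, safe='')}"

assert content_disposition("report.pdf") == 'attachment; filename="report.pdf"'
assert content_disposition("日本語.pdf", "inline").startswith("inline; filename*=UTF-8''")
```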
logger = logging.getLogger(__name__)
@@ -60,19 +62,26 @@ router = fastapi.APIRouter(
)


def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
def _create_streaming_response(
    content: bytes, file: WorkspaceFile, *, inline: bool = False
) -> Response:
    """Create a streaming response for file content."""
    disposition = _sanitize_filename_for_header(
        file.name, disposition="inline" if inline else "attachment"
    )
    return Response(
        content=content,
        media_type=file.mime_type,
        headers={
            "Content-Disposition": _sanitize_filename_for_header(file.name),
            "Content-Disposition": disposition,
            "Content-Length": str(len(content)),
        },
    )


async def _create_file_download_response(file: WorkspaceFile) -> Response:
async def create_file_download_response(
    file: WorkspaceFile, *, inline: bool = False
) -> Response:
    """
    Create a download response for a workspace file.

@@ -84,7 +93,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
    # For local storage, stream the file directly
    if file.storage_path.startswith("local://"):
        content = await storage.retrieve(file.storage_path)
        return _create_streaming_response(content, file)
        return _create_streaming_response(content, file, inline=inline)

    # For GCS, try to redirect to signed URL, fall back to streaming
    try:
@@ -92,7 +101,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
        # If we got back an API path (fallback), stream directly instead
        if url.startswith("/api/"):
            content = await storage.retrieve(file.storage_path)
            return _create_streaming_response(content, file)
            return _create_streaming_response(content, file, inline=inline)
        return fastapi.responses.RedirectResponse(url=url, status_code=302)
    except Exception as e:
        # Log the signed URL failure with context
@@ -104,7 +113,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
        # Fall back to streaming directly from GCS
        try:
            content = await storage.retrieve(file.storage_path)
            return _create_streaming_response(content, file)
            return _create_streaming_response(content, file, inline=inline)
        except Exception as fallback_error:
            logger.error(
                f"Fallback streaming also failed for file {file.id} "
@@ -171,7 +180,7 @@ async def download_file(
    if file is None:
        raise fastapi.HTTPException(status_code=404, detail="File not found")

    return await _create_file_download_response(file)
    return await create_file_download_response(file)


@router.delete(

@@ -630,3 +630,221 @@ def test_get_storage_usage_returns_tier_based_limit(mocker):
    assert data["limit_bytes"] == 1024 * 1024 * 1024
    assert data["used_bytes"] == 100 * 1024 * 1024
    assert data["file_count"] == 5


# -- _sanitize_filename_for_header tests --


class TestSanitizeFilenameForHeader:
    def test_simple_ascii_attachment(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        assert _sanitize_filename_for_header("report.pdf") == (
            'attachment; filename="report.pdf"'
        )

    def test_inline_disposition(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        assert _sanitize_filename_for_header("image.png", disposition="inline") == (
            'inline; filename="image.png"'
        )

    def test_strips_cr_lf_null(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
        assert "\r" not in result
        assert "\n" not in result
        assert "\x00" not in result
        assert 'filename="abcd.txt"' in result

    def test_escapes_quotes(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        result = _sanitize_filename_for_header('file"name.txt')
        assert 'filename="file\\"name.txt"' in result

    def test_header_injection_blocked(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
        # CR/LF stripped — the remaining text is safely inside the quoted value
        assert "\r" not in result
        assert "\n" not in result
        assert result == 'attachment; filename="evil.txtX-Injected: true"'

    def test_unicode_uses_rfc5987(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        result = _sanitize_filename_for_header("日本語.pdf")
        assert "filename*=UTF-8''" in result
        assert "attachment" in result

    def test_unicode_inline(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        result = _sanitize_filename_for_header("图片.png", disposition="inline")
        assert result.startswith("inline; filename*=UTF-8''")

    def test_empty_filename(self):
        from backend.api.features.workspace.routes import _sanitize_filename_for_header

        result = _sanitize_filename_for_header("")
        assert result == 'attachment; filename=""'


# -- _create_streaming_response tests --


class TestCreateStreamingResponse:
    def test_attachment_disposition_by_default(self):
        from backend.api.features.workspace.routes import _create_streaming_response

        file = _make_file(name="data.bin", mime_type="application/octet-stream")
        response = _create_streaming_response(b"binary-data", file)
        assert (
            response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
        )
        assert response.headers["Content-Type"] == "application/octet-stream"
        assert response.headers["Content-Length"] == "11"
        assert response.body == b"binary-data"

    def test_inline_disposition(self):
        from backend.api.features.workspace.routes import _create_streaming_response

        file = _make_file(name="photo.png", mime_type="image/png")
        response = _create_streaming_response(b"\x89PNG", file, inline=True)
        assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
        assert response.headers["Content-Type"] == "image/png"

    def test_inline_sanitizes_filename(self):
        from backend.api.features.workspace.routes import _create_streaming_response

        file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
        response = _create_streaming_response(b"data", file, inline=True)
        assert "\r" not in response.headers["Content-Disposition"]
        assert "\n" not in response.headers["Content-Disposition"]
        assert "inline" in response.headers["Content-Disposition"]

    def test_content_length_matches_body(self):
        from backend.api.features.workspace.routes import _create_streaming_response

        content = b"x" * 1000
        file = _make_file(name="big.bin", mime_type="application/octet-stream")
        response = _create_streaming_response(content, file)
        assert response.headers["Content-Length"] == "1000"


# -- create_file_download_response tests --


class TestCreateFileDownloadResponse:
    @pytest.mark.asyncio
    async def test_local_storage_returns_streaming_response(self, mocker):
        from backend.api.features.workspace.routes import create_file_download_response

        mock_storage = AsyncMock()
        mock_storage.retrieve.return_value = b"file contents"
        mocker.patch(
            "backend.api.features.workspace.routes.get_workspace_storage",
            return_value=mock_storage,
        )

        file = _make_file(
            storage_path="local://uploads/test.txt",
            mime_type="text/plain",
        )
        response = await create_file_download_response(file)
        assert response.status_code == 200
        assert response.body == b"file contents"
        assert "attachment" in response.headers["Content-Disposition"]

    @pytest.mark.asyncio
    async def test_local_storage_inline(self, mocker):
        from backend.api.features.workspace.routes import create_file_download_response

        mock_storage = AsyncMock()
        mock_storage.retrieve.return_value = b"\x89PNG"
        mocker.patch(
            "backend.api.features.workspace.routes.get_workspace_storage",
            return_value=mock_storage,
        )

        file = _make_file(
            storage_path="local://uploads/photo.png",
            mime_type="image/png",
            name="photo.png",
        )
        response = await create_file_download_response(file, inline=True)
        assert "inline" in response.headers["Content-Disposition"]

    @pytest.mark.asyncio
    async def test_gcs_redirect(self, mocker):
        from backend.api.features.workspace.routes import create_file_download_response

        mock_storage = AsyncMock()
        mock_storage.get_download_url.return_value = (
            "https://storage.googleapis.com/signed-url"
        )
        mocker.patch(
            "backend.api.features.workspace.routes.get_workspace_storage",
            return_value=mock_storage,
        )

        file = _make_file(storage_path="gcs://bucket/file.pdf")
        response = await create_file_download_response(file)
        assert response.status_code == 302
        assert (
            response.headers["location"] == "https://storage.googleapis.com/signed-url"
        )

    @pytest.mark.asyncio
    async def test_gcs_api_fallback_streams_directly(self, mocker):
        from backend.api.features.workspace.routes import create_file_download_response

        mock_storage = AsyncMock()
        mock_storage.get_download_url.return_value = "/api/fallback"
        mock_storage.retrieve.return_value = b"fallback content"
        mocker.patch(
            "backend.api.features.workspace.routes.get_workspace_storage",
            return_value=mock_storage,
        )

        file = _make_file(storage_path="gcs://bucket/file.txt")
        response = await create_file_download_response(file)
        assert response.status_code == 200
        assert response.body == b"fallback content"

    @pytest.mark.asyncio
    async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
        from backend.api.features.workspace.routes import create_file_download_response

        mock_storage = AsyncMock()
        mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
        mock_storage.retrieve.return_value = b"streamed"
        mocker.patch(
            "backend.api.features.workspace.routes.get_workspace_storage",
            return_value=mock_storage,
        )

        file = _make_file(storage_path="gcs://bucket/file.txt")
        response = await create_file_download_response(file)
        assert response.status_code == 200
        assert response.body == b"streamed"

    @pytest.mark.asyncio
    async def test_gcs_total_failure_raises(self, mocker):
        from backend.api.features.workspace.routes import create_file_download_response

        mock_storage = AsyncMock()
        mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
        mock_storage.retrieve.side_effect = RuntimeError("Also failed")
        mocker.patch(
            "backend.api.features.workspace.routes.get_workspace_storage",
            return_value=mock_storage,
        )

        file = _make_file(storage_path="gcs://bucket/file.txt")
        with pytest.raises(RuntimeError, match="Also failed"):
            await create_file_download_response(file)

@@ -17,6 +17,7 @@ from fastapi.routing import APIRoute
from prisma.errors import PrismaError

import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.diagnostics_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
@@ -31,6 +32,7 @@ import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.platform_linking.routes
import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
@@ -320,6 +322,11 @@ app.include_router(
    tags=["v2", "admin"],
    prefix="/api/credits",
)
app.include_router(
    backend.api.features.admin.diagnostics_admin_routes.router,
    tags=["v2", "admin"],
    prefix="/api",
)
app.include_router(
    backend.api.features.admin.execution_analytics_routes.router,
    tags=["v2", "admin"],
@@ -372,6 +379,11 @@ app.include_router(
    tags=["oauth"],
    prefix="/api/oauth",
)
app.include_router(
    backend.api.features.platform_linking.routes.router,
    tags=["platform-linking"],
    prefix="/api/platform-linking",
)

app.mount("/external-api", external_api)


@@ -42,11 +42,13 @@ def main(**kwargs):
    from backend.data.db_manager import DatabaseManager
    from backend.executor import ExecutionManager, Scheduler
    from backend.notifications import NotificationManager
    from backend.platform_linking.manager import PlatformLinkingManager

    run_processes(
        DatabaseManager().set_log_level("warning"),
        Scheduler(),
        NotificationManager(),
        PlatformLinkingManager(),
        WebsocketServer(),
        AgentServer(),
        ExecutionManager(),

@@ -96,27 +96,64 @@ class BlockCategory(Enum):


class BlockCostType(str, Enum):
    RUN = "run"  # cost X credits per run
    BYTE = "byte"  # cost X credits per byte
    SECOND = "second"  # cost X credits per second
    # RUN : cost_amount credits per run.
    # BYTE : cost_amount credits per byte of input data.
    # SECOND : cost_amount credits per cost_divisor walltime seconds.
    # ITEMS : cost_amount credits per cost_divisor items (from stats).
    # COST_USD : cost_amount credits per USD of stats.provider_cost.
    # TOKENS : per-(model, provider) rate table lookup; see TOKEN_COST.
    RUN = "run"
    BYTE = "byte"
    SECOND = "second"
    ITEMS = "items"
    COST_USD = "cost_usd"
    TOKENS = "tokens"

    @property
    def is_dynamic(self) -> bool:
        """Real charge is computed post-flight from stats.

        Dynamic types (SECOND/ITEMS/COST_USD/TOKENS) return 0 pre-flight and
        settle against stats via charge_reconciled_usage once the block runs.
        """
        return self in _DYNAMIC_COST_TYPES


_DYNAMIC_COST_TYPES: frozenset[BlockCostType] = frozenset(
    {
        BlockCostType.SECOND,
        BlockCostType.ITEMS,
        BlockCostType.COST_USD,
        BlockCostType.TOKENS,
    }
)


class BlockCost(BaseModel):
    cost_amount: int
    cost_filter: BlockInput
    cost_type: BlockCostType
    # cost_divisor: interpret cost_amount as "credits per cost_divisor units".
    # Only meaningful for SECOND / ITEMS. TOKENS routes through TOKEN_COST
    # rate tables (per-model input/output/cache pricing) and ignores
    # cost_divisor entirely. Defaults to 1 so existing RUN/BYTE entries stay
    # point-wise. Example: cost_amount=1, cost_divisor=10 under SECOND means
    # "1 credit per 10 seconds of walltime".
    cost_divisor: int = 1

    def __init__(
        self,
        cost_amount: int,
        cost_type: BlockCostType = BlockCostType.RUN,
        cost_filter: Optional[BlockInput] = None,
        cost_divisor: int = 1,
        **data: Any,
    ) -> None:
        super().__init__(
            cost_amount=cost_amount,
            cost_filter=cost_filter or {},
            cost_type=cost_type,
            cost_divisor=max(1, cost_divisor),
            **data,
        )

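A hedged sketch of how a SECOND-type cost with `cost_divisor` could settle post-flight from stats (the names and the ceiling rounding are assumptions for illustration; the platform's actual reconciliation lives in `charge_reconciled_usage`):

```python
# Sketch only: "cost_amount credits per cost_divisor seconds" of walltime.
from math import ceil

def reconcile_second_cost(cost_amount: int, cost_divisor: int, walltime_s: float) -> int:
    # E.g. cost_amount=1, cost_divisor=10 bills 1 credit per 10 seconds.
    # Rounding up is an assumption, not confirmed platform behaviour.
    return cost_amount * ceil(walltime_s / max(1, cost_divisor))

assert reconcile_second_cost(1, 10, 25.0) == 3   # 25s of walltime -> 3 credits
assert reconcile_second_cost(5, 1, 2.0) == 10    # plain per-second rate
```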
@@ -168,9 +205,31 @@ class BlockSchema(BaseModel):
        return cls.cached_jsonschema

    @classmethod
    def validate_data(cls, data: BlockInput) -> str | None:
    def validate_data(
        cls,
        data: BlockInput,
        exclude_fields: set[str] | None = None,
    ) -> str | None:
        schema = cls.jsonschema()
        if exclude_fields:
            # Drop the excluded fields from both the properties and the
            # ``required`` list so jsonschema doesn't flag them as missing.
            # Used by the dry-run path to skip credentials validation while
            # still validating the remaining block inputs.
            schema = {
                **schema,
                "properties": {
                    k: v
                    for k, v in schema.get("properties", {}).items()
                    if k not in exclude_fields
                },
                "required": [
                    r for r in schema.get("required", []) if r not in exclude_fields
                ],
            }
            data = {k: v for k, v in data.items() if k not in exclude_fields}
        return json.validate_with_jsonschema(
            schema=cls.jsonschema(),
            schema=schema,
            data={k: v for k, v in data.items() if v is not None},
        )

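The key invariant of the `exclude_fields` transform is that an excluded field must leave both `properties` and `required`, otherwise the validator still reports it as missing. A standalone sketch with plain dicts:

```python
# Sketch only: excluding "credentials" from a JSON schema before validation.
schema = {
    "type": "object",
    "properties": {"credentials": {"type": "object"}, "query": {"type": "string"}},
    "required": ["credentials", "query"],
}
exclude = {"credentials"}

filtered = {
    **schema,
    "properties": {k: v for k, v in schema["properties"].items() if k not in exclude},
    "required": [r for r in schema["required"] if r not in exclude],
}
assert filtered["required"] == ["query"]
assert "credentials" not in filtered["properties"]
```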
@@ -311,6 +370,8 @@ class BlockSchema(BaseModel):
                "credentials_provider": [config.get("provider", "google")],
                "credentials_types": [config.get("type", "oauth2")],
                "credentials_scopes": config.get("scopes"),
                "is_auto_credential": True,
                "input_field_name": info["field_name"],
            }
            result[kwarg_name] = CredentialsFieldInfo.model_validate(
                auto_schema, by_alias=True
@@ -421,19 +482,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
    _optimized_description: ClassVar[str | None] = None

    def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
        """Return extra runtime cost to charge after this block run completes.

        Called by the executor after a block finishes with COMPLETED status.
        The return value is the number of additional base-cost credits to
        charge beyond the single credit already collected by charge_usage
        at the start of execution. Defaults to 0 (no extra charges).

        Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
        calls within one run and should be billed per call.
        """
        return 0

    def __init__(
        self,
        id: str = "",
@@ -717,11 +765,16 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        # (e.g. AgentExecutorBlock) get proper input validation.
        is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
        if is_dry_run:
            # Credential fields may be absent (LLM-built agents often skip
            # wiring them) or nullified earlier in the pipeline. Validate
            # the non-credential inputs against a schema with those fields
            # excluded — stripping only the data while keeping them in the
            # ``required`` list would falsely report ``'credentials' is a
            # required property``.
            cred_field_names = set(self.input_schema.get_credentials_fields().keys())
            non_cred_data = {
                k: v for k, v in input_data.items() if k not in cred_field_names
            }
            if error := self.input_schema.validate_data(non_cred_data):
            if error := self.input_schema.validate_data(
                input_data, exclude_fields=cred_field_names
            ):
                raise BlockInputError(
                    message=f"Unable to execute block with invalid input data: {error}",
                    block_name=self.name,
@@ -735,6 +788,61 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
                block_id=self.id,
            )

        # Ensure auto-credential kwargs are present before we hand off to
        # run(). A missing auto-credential means the upstream field (e.g.
        # a Google Drive picker) didn't embed a _credentials_id, or the
        # executor couldn't resolve it. Without this guard, run() would
        # crash with a TypeError (missing required kwarg) or an opaque
        # AttributeError deep inside the provider SDK.
        #
        # Only raise when the field is ALSO not populated in input_data.
        # ``_acquire_auto_credentials`` intentionally skips setting the
        # kwarg in two legitimate cases — ``_credentials_id`` is ``None``
        # (chained from upstream) or the field is missing from
        # ``input_data`` at prep time (connected from upstream block).
        # In both cases the upstream block is expected to populate the
        # field value by execute time; raising here would break the
        # documented ``AgentGoogleDriveFileInputBlock`` chaining pattern.
        # Dry-run skips because the executor intentionally runs blocks
        # without resolved creds for schema validation.
        if not is_dry_run:
            for (
                kwarg_name,
                info,
            ) in self.input_schema.get_auto_credentials_fields().items():
                kwargs.setdefault(kwarg_name, None)
                if kwargs[kwarg_name] is not None:
                    continue
                # Upstream-chained pattern: the field was populated by a
                # prior node (e.g. AgentGoogleDriveFileInputBlock) whose
                # output carries a resolved ``_credentials_id``.
                # ``_acquire_auto_credentials`` deliberately doesn't set
                # the kwarg in that case because the value isn't available
                # at prep time; the executor fills it in before we reach
                # ``_execute``. Trust it if the ``_credentials_id`` KEY
                # is present — its value may be explicitly ``None`` in
                # the chained case (see sentry thread
                # PRRT_kwDOJKSTjM58sJfA). Checking truthiness here would
                # falsely preempt run() for every valid chained graph
                # that ships ``_credentials_id=None`` in the picker
                # object. Mirror ``_acquire_auto_credentials``'s own
                # skip rule, which treats ``cred_id is None`` as a
                # chained-skip signal.
                field_name = info["field_name"]
                field_value = input_data.get(field_name)
                if isinstance(field_value, dict) and "_credentials_id" in field_value:
                    continue
                raise BlockExecutionError(
                    message=(
                        f"Missing credentials for '{kwarg_name}'. "
                        "Select a file via the picker (which carries "
                        "its credentials), or connect credentials for "
                        "this block."
                    ),
                    block_name=self.name,
                    block_id=self.id,
                )

        # Use the validated input data
        async for output_name, output_data in self.run(
            self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),

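The subtle decision in that guard is key presence versus truthiness: the `_credentials_id` key being present is trusted even when its value is `None` (the chained case), and only a missing key trips the error. A data-only sketch of that rule (hypothetical field shapes, not the platform types):

```python
# Sketch only: key presence, not truthiness, decides whether to raise.
def needs_credentials_error(field_value) -> bool:
    return not (isinstance(field_value, dict) and "_credentials_id" in field_value)

assert not needs_credentials_error({"_credentials_id": None})   # chained -> trusted
assert not needs_credentials_error({"_credentials_id": "c-1"})  # resolved -> trusted
assert needs_credentials_error({})                               # no key -> raise
assert needs_credentials_error(None)                             # unset -> raise
```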
@@ -171,7 +171,10 @@ class AgentExecutorBlock(Block):
            )
            self.merge_stats(
                NodeExecutionStats(
                    extra_cost=event.stats.cost if event.stats else 0,
                    # Sub-graph already debited each of its own nodes; we
                    # roll up its total so graph_stats.cost reflects the
                    # full sub-graph spend.
                    reconciled_cost_delta=(event.stats.cost if event.stats else 0),
                    extra_steps=event.stats.node_exec_count if event.stats else 0,
                )
            )

@@ -4,11 +4,16 @@ Shared configuration for all AgentMail blocks.

from agentmail import AsyncAgentMail

from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, SecretStr

# AgentMail is in beta with no published paid tier yet, but ~37 blocks
# without any BLOCK_COSTS entry means they currently execute wallet-free.
# 1 cr/call is a conservative interim floor so no AgentMail work leaks
# past billing. Revisit once AgentMail publishes usage-based pricing.
agent_mail = (
    ProviderBuilder("agent_mail")
    .with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)

21  autogpt_platform/backend/backend/blocks/ayrshare/_config.py  (new file)
@@ -0,0 +1,21 @@
"""Shared provider config for Ayrshare social-media blocks.

The "credential" exposed to blocks is the **per-user Ayrshare profile key**,
not the org-level ``AYRSHARE_API_KEY``. Profile keys are provisioned per
user by :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`
and stored in the normal credentials list with ``is_managed=True``, so every
Ayrshare block fits the standard credential flow:

    credentials: CredentialsMetaInput = ayrshare.credentials_field(...)

``run_block`` / ``resolve_block_credentials`` take care of the rest.

``with_managed_api_key()`` registers ``api_key`` as a supported auth type
without the env-var-backed default credential that ``with_api_key()`` would
create — the org-level ``AYRSHARE_API_KEY`` is the admin key and must never
reach a block as a "profile key".
"""

from backend.sdk import ProviderBuilder

ayrshare = ProviderBuilder("ayrshare").with_managed_api_key().build()

18  autogpt_platform/backend/backend/blocks/ayrshare/_cost.py  (new file)
@@ -0,0 +1,18 @@
from backend.sdk import BlockCost, BlockCostType

# Ayrshare is a subscription proxy ($149/mo Business). Per-post credit charges
# prevent a single heavy user from absorbing the fixed cost and align with the
# upload cost of each post variant.
# cost_filter matches on input_data.is_video BEFORE run() executes, so the flag
# has to be correct at input-eval time. Video-only platforms (YouTube, Snapchat)
# override the base default to True; platforms that accept both (TikTok, etc.)
# rely on the caller setting is_video explicitly for accurate billing.
# First match wins in block_usage_cost, so list the video tier first.
AYRSHARE_POST_COSTS = (
    BlockCost(
        cost_amount=5, cost_type=BlockCostType.RUN, cost_filter={"is_video": True}
    ),
    BlockCost(
        cost_amount=2, cost_type=BlockCostType.RUN, cost_filter={"is_video": False}
    ),
)
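A simplified sketch of the first-match-wins selection the comment above describes (the real matcher is `block_usage_cost` elsewhere in the platform; this stand-in uses plain dicts purely for illustration):

```python
# Sketch only: the first cost entry whose filter matches the input wins.
def pick_cost(costs, input_data: dict) -> int:
    for c in costs:
        if all(input_data.get(k) == v for k, v in c["cost_filter"].items()):
            return c["cost_amount"]
    return 0

costs = (
    {"cost_amount": 5, "cost_filter": {"is_video": True}},   # video tier listed first
    {"cost_amount": 2, "cost_filter": {"is_video": False}},
)
assert pick_cost(costs, {"is_video": True}) == 5
assert pick_cost(costs, {"is_video": False}) == 2
```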
@@ -4,22 +4,25 @@ from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.blocks._base import BlockSchemaInput
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.data.model import CredentialsMetaInput, SchemaField
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
async def get_profile_key(user_id: str):
|
||||
user_integrations: UserIntegrations = (
|
||||
await get_database_manager_async_client().get_user_integrations(user_id)
|
||||
)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
from ._config import ayrshare
|
||||
|
||||
|
||||
class BaseAyrshareInput(BlockSchemaInput):
|
||||
"""Base input model for Ayrshare social media posts with common fields."""
|
||||
|
||||
credentials: CredentialsMetaInput = ayrshare.credentials_field(
|
||||
description=(
|
||||
"Ayrshare profile credential. AutoGPT provisions this managed "
|
||||
"credential automatically — the user does not create it. After "
|
||||
"it's in place, the user links each social account via the "
|
||||
"Ayrshare SSO popup in the Builder."
|
||||
),
|
||||
)
|
||||
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published", default="", advanced=False
|
||||
)
|
||||
@@ -29,7 +32,9 @@ class BaseAyrshareInput(BlockSchemaInput):
|
||||
advanced=False,
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video", default=False, advanced=True
|
||||
description="Whether the media is a video. Set to True when uploading a video so billing applies the video tier.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
schedule_date: Optional[datetime] = SchemaField(
|
||||
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToBlueskyBlock(Block):
|
||||
"""Block for posting to Bluesky with Bluesky-specific options."""
|
||||
|
||||
@@ -57,16 +61,10 @@ class PostToBlueskyBlock(Block):
|
||||
self,
|
||||
input_data: "PostToBlueskyBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Bluesky with Bluesky-specific options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -106,7 +104,7 @@ class PostToBlueskyBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
bluesky_options=bluesky_options if bluesky_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
CarouselItem,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToFacebookBlock(Block):
|
||||
"""Block for posting to Facebook with Facebook-specific options."""
|
||||
|
||||
@@ -120,15 +119,10 @@ class PostToFacebookBlock(Block):
|
||||
self,
|
||||
input_data: "PostToFacebookBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Facebook with Facebook-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -204,7 +198,7 @@ class PostToFacebookBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
facebook_options=facebook_options if facebook_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToGMBBlock(Block):
|
||||
"""Block for posting to Google My Business with GMB-specific options."""
|
||||
|
||||
@@ -110,14 +114,13 @@ class PostToGMBBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
|
||||
self,
|
||||
input_data: "PostToGMBBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to Google My Business with GMB-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -202,7 +205,7 @@ class PostToGMBBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
gmb_options=gmb_options if gmb_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -2,22 +2,21 @@ from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
InstagramUserTag,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToInstagramBlock(Block):
|
||||
"""Block for posting to Instagram with Instagram-specific options."""
|
||||
|
||||
@@ -112,15 +111,10 @@ class PostToInstagramBlock(Block):
|
||||
self,
|
||||
input_data: "PostToInstagramBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Instagram with Instagram-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -241,7 +235,7 @@ class PostToInstagramBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
instagram_options=instagram_options if instagram_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToLinkedInBlock(Block):
     """Block for posting to LinkedIn with LinkedIn-specific options."""
 
@@ -112,15 +116,10 @@ class PostToLinkedInBlock(Block):
         self,
         input_data: "PostToLinkedInBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to LinkedIn with LinkedIn-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -214,7 +213,7 @@ class PostToLinkedInBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             linkedin_options=linkedin_options if linkedin_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,21 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import (
-    BaseAyrshareInput,
-    PinterestCarouselOption,
-    create_ayrshare_client,
-    get_profile_key,
-)
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToPinterestBlock(Block):
     """Block for posting to Pinterest with Pinterest-specific options."""
 
@@ -92,15 +91,10 @@ class PostToPinterestBlock(Block):
         self,
         input_data: "PostToPinterestBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Pinterest with Pinterest-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -206,7 +200,7 @@ class PostToPinterestBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             pinterest_options=pinterest_options if pinterest_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToRedditBlock(Block):
     """Block for posting to Reddit."""
 
@@ -35,12 +39,12 @@ class PostToRedditBlock(Block):
         )
 
     async def run(
-        self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
+        self,
+        input_data: "PostToRedditBlock.Input",
+        *,
+        credentials: APIKeyCredentials,
+        **kwargs
     ) -> BlockOutput:
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured."
@@ -61,7 +65,7 @@ class PostToRedditBlock(Block):
             random_post=input_data.random_post,
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToSnapchatBlock(Block):
     """Block for posting to Snapchat with Snapchat-specific options."""
 
@@ -31,6 +35,14 @@ class PostToSnapchatBlock(Block):
             advanced=False,
         )
 
+        # Snapchat is video-only; override the base default so the @cost filter
+        # selects the 5-credit video tier instead of the 2-credit image tier.
+        is_video: bool = SchemaField(
+            description="Whether the media is a video (always True for Snapchat)",
+            default=True,
+            advanced=True,
+        )
+
         # Snapchat-specific options
         story_type: str = SchemaField(
             description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
@@ -62,15 +74,10 @@ class PostToSnapchatBlock(Block):
         self,
         input_data: "PostToSnapchatBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Snapchat with Snapchat-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -121,7 +128,7 @@ class PostToSnapchatBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             snapchat_options=snapchat_options if snapchat_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToTelegramBlock(Block):
     """Block for posting to Telegram with Telegram-specific options."""
 
@@ -57,15 +61,10 @@ class PostToTelegramBlock(Block):
         self,
         input_data: "PostToTelegramBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Telegram with Telegram-specific validation."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -108,7 +107,7 @@ class PostToTelegramBlock(Block):
             random_post=input_data.random_post,
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToThreadsBlock(Block):
     """Block for posting to Threads with Threads-specific options."""
 
@@ -50,15 +54,10 @@ class PostToThreadsBlock(Block):
         self,
         input_data: "PostToThreadsBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Threads with Threads-specific validation."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -103,7 +102,7 @@ class PostToThreadsBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             threads_options=threads_options if threads_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
        )
         yield "post_result", response
         if response.postIds:
@@ -2,15 +2,18 @@ from enum import Enum
 
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
 class TikTokVisibility(str, Enum):
@@ -19,6 +22,7 @@ class TikTokVisibility(str, Enum):
     FOLLOWERS = "followers"
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToTikTokBlock(Block):
     """Block for posting to TikTok with TikTok-specific options."""
 
@@ -113,14 +117,13 @@ class PostToTikTokBlock(Block):
         )
 
     async def run(
-        self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
+        self,
+        input_data: "PostToTikTokBlock.Input",
+        *,
+        credentials: APIKeyCredentials,
+        **kwargs,
     ) -> BlockOutput:
         """Post to TikTok with TikTok-specific validation and options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -235,7 +238,7 @@ class PostToTikTokBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             tiktok_options=tiktok_options if tiktok_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToXBlock(Block):
     """Block for posting to X / Twitter with Twitter-specific options."""
 
@@ -115,15 +119,10 @@ class PostToXBlock(Block):
         self,
         input_data: "PostToXBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to X / Twitter with enhanced X-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -233,7 +232,7 @@ class PostToXBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             twitter_options=twitter_options if twitter_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -3,15 +3,18 @@ from typing import Any
 
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
 class YouTubeVisibility(str, Enum):
@@ -20,6 +23,7 @@ class YouTubeVisibility(str, Enum):
     UNLISTED = "unlisted"
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToYouTubeBlock(Block):
     """Block for posting to YouTube with YouTube-specific options."""
 
@@ -39,6 +43,14 @@ class PostToYouTubeBlock(Block):
             advanced=False,
         )
 
+        # YouTube is video-only; override the base default so the @cost filter
+        # selects the 5-credit video tier instead of the 2-credit image tier.
+        is_video: bool = SchemaField(
+            description="Whether the media is a video (always True for YouTube)",
+            default=True,
+            advanced=True,
+        )
+
         # YouTube-specific required options
         title: str = SchemaField(
             description="Video title (max 100 chars, required). Cannot contain < or > characters.",
@@ -137,16 +149,10 @@ class PostToYouTubeBlock(Block):
         self,
         input_data: "PostToYouTubeBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to YouTube with YouTube-specific validation and options."""
-
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -302,7 +308,7 @@ class PostToYouTubeBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             youtube_options=youtube_options,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -4,21 +4,34 @@ Meeting BaaS bot (recording) blocks.
 
 from typing import Optional
 
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
     BlockCategory,
+    BlockCost,
+    BlockCostType,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
+    cost,
 )
 
 from ._api import MeetingBaasAPI
 from ._config import baas
 
+# Meeting BaaS recording rate: $0.69 per hour.
+_MEETING_BAAS_USD_PER_SECOND = 0.69 / 3600
+
+# Join bills a flat 30 cr commit (covers median short meeting);
+# FetchMeetingData bills the duration-scaled remainder from the
+# `duration_seconds` field on the API response. Long meetings no
+# longer under-bill.
+
 
+@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=30))
 class BaasBotJoinMeetingBlock(Block):
     """
     Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
@@ -134,6 +147,7 @@ class BaasBotLeaveMeetingBlock(Block):
         yield "left", left
 
 
+@cost(BlockCost(cost_type=BlockCostType.COST_USD, cost_amount=150))
 class BaasBotFetchMeetingDataBlock(Block):
     """
     Pull MP4 URL, transcript & metadata for a completed meeting.
@@ -176,9 +190,21 @@ class BaasBotFetchMeetingDataBlock(Block):
             include_transcripts=input_data.include_transcripts,
         )
 
+        bot_meta = data.get("bot_data", {}).get("bot", {}) or {}
+        # Bill recording duration via COST_USD so multi-hour meetings
+        # scale past the Join block's flat 30 cr deposit.
+        duration_seconds = float(bot_meta.get("duration_seconds") or 0)
+        if duration_seconds > 0:
+            self.merge_stats(
+                NodeExecutionStats(
+                    provider_cost=duration_seconds * _MEETING_BAAS_USD_PER_SECOND,
+                    provider_cost_type="cost_usd",
+                )
+            )
+
         yield "mp4_url", data.get("mp4", "")
         yield "transcript", data.get("bot_data", {}).get("transcripts", [])
-        yield "metadata", data.get("bot_data", {}).get("bot", {})
+        yield "metadata", bot_meta
 
 
 class BaasBotDeleteRecordingBlock(Block):
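
A rough worked example of the split billing, assuming the COST_USD resolver charges `ceil(usd * cost_amount)` credits (the rounding the claude_code tests later in this commit assert); the variable names here are illustrative, not part of the codebase:

```python
import math

USD_PER_SECOND = 0.69 / 3600   # published $0.69/hr rate
CREDITS_PER_USD = 150          # cost_amount on the FetchMeetingData entry

duration_seconds = 2 * 3600                        # a 2-hour recording
usd = duration_seconds * USD_PER_SECOND            # $1.38
fetch_credits = math.ceil(usd * CREDITS_PER_USD)   # 207 credits
total_credits = 30 + fetch_credits                 # plus Join's flat 30 cr commit
```
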
@@ -0,0 +1,86 @@
"""Unit tests for Meeting BaaS duration-based cost emission."""

from unittest.mock import AsyncMock, patch

import pytest
from pydantic import SecretStr

from backend.blocks.baas.bots import (
    _MEETING_BAAS_USD_PER_SECOND,
    BaasBotFetchMeetingDataBlock,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="baas",
    title="Mock BaaS API Key",
    api_key=SecretStr("mock-baas-api-key"),
    expires_at=None,
)


def test_usd_per_second_derives_from_published_rate():
    """$0.69/hour published rate → ~$0.000192/second."""
    assert _MEETING_BAAS_USD_PER_SECOND == pytest.approx(0.69 / 3600)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "duration_seconds, expected_usd",
    [
        (3600, 0.69),  # 1 hour
        (1800, 0.345),  # 30 min
        (0, None),  # no recording → no emission
        (None, None),  # missing duration field → no emission
    ],
)
async def test_fetch_meeting_data_emits_duration_cost_usd(
    duration_seconds, expected_usd
):
    """FetchMeetingData extracts duration_seconds from bot metadata and
    emits provider_cost / cost_usd scaled by the published $0.69/hr rate.
    Emission is skipped when duration is 0 or missing.
    """
    block = BaasBotFetchMeetingDataBlock()

    bot_meta = {"id": "bot-xyz"}
    if duration_seconds is not None:
        bot_meta["duration_seconds"] = duration_seconds

    mock_api = AsyncMock()
    mock_api.get_meeting_data.return_value = {
        "mp4": "https://example/recording.mp4",
        "bot_data": {"bot": bot_meta, "transcripts": []},
    }

    captured: list[NodeExecutionStats] = []
    with (
        patch("backend.blocks.baas.bots.MeetingBaasAPI", return_value=mock_api),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        outputs = []
        async for name, val in block.run(
            block.input_schema(
                credentials={
                    "id": TEST_CREDENTIALS.id,
                    "provider": TEST_CREDENTIALS.provider,
                    "type": TEST_CREDENTIALS.type,
                },
                bot_id="bot-xyz",
                include_transcripts=False,
            ),
            credentials=TEST_CREDENTIALS,
        ):
            outputs.append((name, val))

    # Always yields the 3 outputs regardless of duration.
    names = [n for n, _ in outputs]
    assert "mp4_url" in names and "metadata" in names

    if expected_usd is None:
        assert captured == []
    else:
        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(expected_usd)
        assert captured[0].provider_cost_type == "cost_usd"
@@ -3,6 +3,6 @@ from backend.sdk import BlockCostType, ProviderBuilder
 bannerbear = (
     ProviderBuilder("bannerbear")
     .with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
-    .with_base_cost(1, BlockCostType.RUN)
+    .with_base_cost(3, BlockCostType.RUN)
     .build()
 )
@@ -17,6 +17,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -431,6 +432,7 @@ class ClaudeCodeBlock(Block):
             # The JSON output contains the result
             output_data = json.loads(raw_output)
             response = output_data.get("result", raw_output)
+            self._record_cli_cost(output_data)
 
             # Build conversation history entry
             turn_entry = f"User: {prompt}\nClaude: {response}"
@@ -484,6 +486,23 @@ class ClaudeCodeBlock(Block):
         escaped = prompt.replace("'", "'\"'\"'")
         return f"'{escaped}'"
 
+    def _record_cli_cost(self, output_data: dict) -> None:
+        """Feed Claude Code CLI's `total_cost_usd` to the COST_USD resolver.
+
+        The CLI rolls up Anthropic LLM + internal tool-call spend into
+        ``total_cost_usd`` on its JSON response; piping it through
+        ``merge_stats`` lets the wallet reflect real spend.
+        """
+        total_cost_usd = output_data.get("total_cost_usd")
+        if total_cost_usd is None:
+            return
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=float(total_cost_usd),
+                provider_cost_type="cost_usd",
+            )
+        )
+
     async def run(
         self,
         input_data: Input,
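
The conversion that turns this emission into credits is asserted by the new test file below. As a standalone sketch of the arithmetic only (not the resolver's actual implementation, and assuming the registered cost_amount of 150 credits per USD):

```python
import math

def usd_to_credits(total_cost_usd: float, cost_amount: int = 150) -> int:
    # ceil() so sub-cent runs still bill at least 1 credit.
    return math.ceil(total_cost_usd * cost_amount)

assert usd_to_credits(0.0134) == 3  # ceil(2.01)
assert usd_to_credits(0.001) == 1   # no 0-credit leak
```
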
autogpt_platform/backend/backend/blocks/claude_code_cost_test.py (new file, 106 lines)
@@ -0,0 +1,106 @@
"""Unit tests for ClaudeCodeBlock COST_USD billing migration.

Verifies:
- Block emits provider_cost / cost_usd when Claude Code CLI returns
  total_cost_usd.
- block_usage_cost resolves the COST_USD entry to the expected ceil(usd *
  cost_amount) credit charge.
- Missing total_cost_usd gracefully produces provider_cost=None (no bill).
"""

from unittest.mock import MagicMock, patch

import pytest

from backend.blocks._base import BlockCostType
from backend.blocks.claude_code import ClaudeCodeBlock
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.model import NodeExecutionStats
from backend.executor.utils import block_usage_cost


def test_claude_code_registered_as_cost_usd_150():
    """Sanity: BLOCK_COSTS holds the COST_USD, 150 cr/$ entry."""
    entries = BLOCK_COSTS[ClaudeCodeBlock]
    assert len(entries) == 1
    entry = entries[0]
    assert entry.cost_type == BlockCostType.COST_USD
    assert entry.cost_amount == 150


@pytest.mark.parametrize(
    "total_cost_usd, expected_credits",
    [
        (0.50, 75),  # $0.50 × 150 = 75 cr
        (1.00, 150),  # $1.00 × 150 = 150 cr
        (0.0134, 3),  # ceil(0.0134 × 150) = ceil(2.01) = 3
        (2.00, 300),  # $2 × 150 = 300 cr
        (0.001, 1),  # ceil(0.001 × 150) = ceil(0.15) = 1 — no 0-cr leak on
        # sub-cent runs
    ],
)
def test_cost_usd_resolver_applies_150_multiplier(total_cost_usd, expected_credits):
    """block_usage_cost with cost_usd stats returns ceil(usd * 150)."""
    block = ClaudeCodeBlock()
    # cost_filter requires matching e2b_credentials; supply the ones the
    # registration uses so _is_cost_filter_match accepts the input.
    entry = BLOCK_COSTS[ClaudeCodeBlock][0]
    input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
    stats = NodeExecutionStats(
        provider_cost=total_cost_usd,
        provider_cost_type="cost_usd",
    )
    cost, matching_filter = block_usage_cost(
        block=block, input_data=input_data, stats=stats
    )
    assert cost == expected_credits
    assert matching_filter == entry.cost_filter


def test_cost_usd_resolver_returns_zero_when_stats_missing_cost():
    """Pre-flight (no stats) or unbilled run (provider_cost None) → 0."""
    block = ClaudeCodeBlock()
    entry = BLOCK_COSTS[ClaudeCodeBlock][0]
    input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
    # No stats at all → pre-flight path, returns 0.
    pre_cost, _ = block_usage_cost(block=block, input_data=input_data)
    assert pre_cost == 0
    # Stats present but no provider_cost → resolver can't bill.
    stats = NodeExecutionStats()
    post_cost, _ = block_usage_cost(block=block, input_data=input_data, stats=stats)
    assert post_cost == 0


def test_record_cli_cost_emits_provider_cost_when_total_cost_present():
    """``_record_cli_cost`` (the helper called from ``execute_claude_code``)
    must emit a single ``merge_stats`` with provider_cost + cost_usd tag
    when the CLI JSON payload carries ``total_cost_usd``.
    """
    block = ClaudeCodeBlock()
    captured: list[NodeExecutionStats] = []
    with patch.object(block, "merge_stats", side_effect=captured.append):
        block._record_cli_cost(
            {
                "result": "hello from claude",
                "total_cost_usd": 0.0421,
                "usage": {"input_tokens": 1234, "output_tokens": 56},
            }
        )

    assert len(captured) == 1
    stats = captured[0]
    assert stats.provider_cost == pytest.approx(0.0421)
    assert stats.provider_cost_type == "cost_usd"


def test_record_cli_cost_skips_merge_when_total_cost_absent():
    """If the CLI payload lacks ``total_cost_usd`` (legacy / non-JSON
    output), ``_record_cli_cost`` must not call ``merge_stats`` — otherwise
    we'd pollute telemetry with a ``cost_usd`` emission that has no real
    cost attached.
    """
    block = ClaudeCodeBlock()
    mock = MagicMock()
    with patch.object(block, "merge_stats", mock):
        block._record_cli_cost({"result": "hello"})
    mock.assert_not_called()
@@ -151,6 +151,17 @@ class CodeGenerationBlock(Block):
         )
         self.execution_stats = NodeExecutionStats()
 
+    # GPT-5.1-Codex published pricing: $1.25 / 1M input, $10 / 1M output.
+    _INPUT_USD_PER_1M = 1.25
+    _OUTPUT_USD_PER_1M = 10.0
+
+    @staticmethod
+    def _compute_token_usd(input_tokens: int, output_tokens: int) -> float:
+        return (
+            input_tokens * CodeGenerationBlock._INPUT_USD_PER_1M
+            + output_tokens * CodeGenerationBlock._OUTPUT_USD_PER_1M
+        ) / 1_000_000
+
     async def call_codex(
         self,
         *,
@@ -189,13 +200,15 @@ class CodeGenerationBlock(Block):
         response_id = response.id or ""
 
         # Update usage stats
-        self.execution_stats.input_token_count = (
-            response.usage.input_tokens if response.usage else 0
-        )
-        self.execution_stats.output_token_count = (
-            response.usage.output_tokens if response.usage else 0
-        )
+        input_tokens = response.usage.input_tokens if response.usage else 0
+        output_tokens = response.usage.output_tokens if response.usage else 0
+        self.execution_stats.input_token_count = input_tokens
+        self.execution_stats.output_token_count = output_tokens
         self.execution_stats.llm_call_count += 1
+        self.execution_stats.provider_cost = self._compute_token_usd(
+            input_tokens, output_tokens
+        )
+        self.execution_stats.provider_cost_type = "cost_usd"
 
         return CodexCallResult(
             response=text_output,
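
For scale, a quick check that mirrors the `_compute_token_usd` arithmetic above (illustrative only):

```python
# 100k input + 10k output tokens at the published GPT-5.1-Codex rates.
input_tokens, output_tokens = 100_000, 10_000
usd = (input_tokens * 1.25 + output_tokens * 10.0) / 1_000_000
assert usd == 0.225  # $0.125 input + $0.100 output
```

At the COST_USD entry's 150 cr/$, that run would bill ceil(0.225 * 150) = 34 credits, assuming the same ceil rounding the claude_code tests assert.
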
autogpt_platform/backend/backend/blocks/cost_leak_fixes_test.py (new file, 226 lines)
@@ -0,0 +1,226 @@
"""Coverage tests for the cost-leak fixes in this PR.

Each block's ``run()`` / helper emits provider_cost + cost_usd (or items)
via merge_stats so the post-flight resolver bills real provider spend.
Tests here drive that emission path directly so a regression on any one
block surfaces immediately.
"""

from unittest.mock import patch

import pytest
from pydantic import SecretStr

from backend.blocks._base import BlockCostType
from backend.blocks.ai_condition import AIConditionBlock
from backend.data.block_cost_config import BLOCK_COSTS, LLM_COST
from backend.data.model import APIKeyCredentials, NodeExecutionStats

# -------- AIConditionBlock registration --------


def test_ai_condition_registered_under_llm_cost():
    """AIConditionBlock was running wallet-free before this PR; verify it
    now resolves through the same per-model LLM_COST table as every other
    LLM block.
    """
    assert BLOCK_COSTS[AIConditionBlock] is LLM_COST


# -------- Pinecone insert ITEMS emission --------


@pytest.mark.asyncio
async def test_pinecone_insert_emits_items_provider_cost():
    from backend.blocks.pinecone import PineconeInsertBlock

    block = PineconeInsertBlock()
    captured: list[NodeExecutionStats] = []

    class _FakeIndex:
        def upsert(self, **_):
            return None

    class _FakePinecone:
        def __init__(self, *_, **__):
            pass

        def Index(self, _name):
            return _FakeIndex()

    with (
        patch("backend.blocks.pinecone.Pinecone", _FakePinecone),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        input_data = block.input_schema(
            credentials={
                "id": "00000000-0000-0000-0000-000000000000",
                "provider": "pinecone",
                "type": "api_key",
            },
            index="my-index",
            chunks=["alpha", "beta", "gamma"],
            embeddings=[[0.1] * 4, [0.2] * 4, [0.3] * 4],
            namespace="",
            metadata={},
        )

        creds = APIKeyCredentials(
            id="00000000-0000-0000-0000-000000000000",
            provider="pinecone",
            title="mock",
            api_key=SecretStr("mock-key"),
            expires_at=None,
        )
        outputs = [(n, v) async for n, v in block.run(input_data, credentials=creds)]

    assert any(name == "upsert_response" for name, _ in outputs)
    assert len(captured) == 1
    stats = captured[0]
    assert stats.provider_cost == pytest.approx(3.0)
    assert stats.provider_cost_type == "items"


# -------- Narration model-aware per-char rate --------


@pytest.mark.parametrize(
    "model_id, expected_rate_per_char",
    [
        ("eleven_flash_v2_5", 0.000167 * 0.5),
        ("eleven_turbo_v2_5", 0.000167 * 0.5),
        ("eleven_multilingual_v2", 0.000167 * 1.0),
        ("eleven_turbo_v2", 0.000167 * 1.0),
    ],
)
def test_narration_per_char_rate_scales_with_model(model_id, expected_rate_per_char):
    """Drive VideoNarrationBlock._record_script_cost directly so a regression
    that drops the model-aware branching (e.g. hardcoding 1.0 cr/char for
    all models) makes this test fail.
    """
    from backend.blocks.video.narration import VideoNarrationBlock

    block = VideoNarrationBlock()
    captured: list[NodeExecutionStats] = []
    with patch.object(block, "merge_stats", side_effect=captured.append):
        block._record_script_cost("x" * 5000, model_id)

    assert len(captured) == 1
    stats = captured[0]
    assert stats.provider_cost == pytest.approx(5000 * expected_rate_per_char)
    assert stats.provider_cost_type == "cost_usd"


# -------- Perplexity None-guard on x-total-cost --------


@pytest.mark.parametrize(
    "openrouter_cost, expect_type",
    [
        (0.0421, "cost_usd"),  # concrete positive USD → tagged
        (None, None),  # header missing → no tag (keeps gap observable)
        (0.0, None),  # zero → no tag (wouldn't bill anything anyway)
    ],
)
def test_perplexity_record_openrouter_cost_tags_only_on_concrete_value(
    openrouter_cost, expect_type
):
    """Drive PerplexityBlock._record_openrouter_cost directly to verify the
    None/0 guard. A regression that tags cost_usd unconditionally would
    silently floor the user's bill to 0 via the resolver — this test
    would catch it.
    """
    from backend.blocks.perplexity import PerplexityBlock

    block = PerplexityBlock()
    with patch(
        "backend.blocks.perplexity.extract_openrouter_cost",
        return_value=openrouter_cost,
    ):
        block._record_openrouter_cost(response=object())

    assert block.execution_stats.provider_cost == openrouter_cost
    assert block.execution_stats.provider_cost_type == expect_type


# -------- Codex COST_USD registration --------


def test_codex_registered_as_cost_usd_150():
    from backend.blocks.codex import CodeGenerationBlock

    entries = BLOCK_COSTS[CodeGenerationBlock]
    assert len(entries) == 1
    entry = entries[0]
    assert entry.cost_type == BlockCostType.COST_USD
    assert entry.cost_amount == 150


@pytest.mark.parametrize(
    "input_tokens, output_tokens, expected_usd",
    [
        # GPT-5.1-Codex: $1.25 / 1M input, $10 / 1M output.
        (1_000_000, 0, 1.25),
        (0, 1_000_000, 10.0),
        (100_000, 10_000, 0.225),  # 0.125 + 0.100
        (0, 0, 0.0),
    ],
)
def test_codex_computes_provider_cost_usd_from_token_counts(
    input_tokens, output_tokens, expected_usd
):
    """Drive CodeGenerationBlock._compute_token_usd directly. A regression
    to the wrong rate constants (e.g. swapping the $1.25 input rate for
    GPT-4o's $2.50) would fail this test.
    """
    from backend.blocks.codex import CodeGenerationBlock

    assert CodeGenerationBlock._compute_token_usd(
        input_tokens, output_tokens
    ) == pytest.approx(expected_usd)


# -------- ClaudeCode COST_USD registration sanity (already tested in claude_code_cost_test.py) --------


# -------- Perplexity COST_USD registration for all 3 tiers --------


def test_perplexity_sonar_all_tiers_registered_as_cost_usd_150():
    from backend.blocks.perplexity import PerplexityBlock

    entries = BLOCK_COSTS[PerplexityBlock]
    # 3 tiers (SONAR, SONAR_PRO, SONAR_DEEP_RESEARCH) all COST_USD 150.
    assert len(entries) == 3
    for entry in entries:
        assert entry.cost_type == BlockCostType.COST_USD
        assert entry.cost_amount == 150


# -------- Narration COST_USD registration --------


def test_narration_registered_as_cost_usd_150():
    from backend.blocks.video.narration import VideoNarrationBlock

    entries = BLOCK_COSTS[VideoNarrationBlock]
    assert len(entries) == 1
    assert entries[0].cost_type == BlockCostType.COST_USD
    assert entries[0].cost_amount == 150


# -------- Pinecone registrations --------


def test_pinecone_registrations():
    from backend.blocks.pinecone import (
        PineconeInitBlock,
        PineconeInsertBlock,
        PineconeQueryBlock,
    )

    assert BLOCK_COSTS[PineconeInitBlock][0].cost_type == BlockCostType.RUN
    assert BLOCK_COSTS[PineconeQueryBlock][0].cost_type == BlockCostType.RUN
    # Insert scales with item count.
    assert BLOCK_COSTS[PineconeInsertBlock][0].cost_type == BlockCostType.ITEMS
    assert BLOCK_COSTS[PineconeInsertBlock][0].cost_amount == 1
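
The tests above exercise two `provider_cost_type` flavors. As a mental model only (a sketch, not the resolver's code; `bill` is a hypothetical helper and the ceil rounding is an assumption carried over from the claude_code tests): both flavors scale the emitted `provider_cost` by the registered `cost_amount`, credits per USD for `"cost_usd"` and credits per item for `"items"`.

```python
import math

def bill(provider_cost: float, cost_amount: int) -> int:
    # Sketch of post-flight resolution: provider_cost x cost_amount, rounded up.
    return math.ceil(provider_cost * cost_amount)

assert bill(3.0, 1) == 3       # Pinecone insert: 3 items at 1 cr/item
assert bill(0.0421, 150) == 7  # Claude Code run: ceil($0.0421 * 150)
```
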
@@ -19,6 +19,10 @@ class DataForSeoClient:
             trusted_origins=["https://api.dataforseo.com"],
             raise_for_status=False,
         )
+        # USD cost reported by DataForSEO on the most recent successful call.
+        # Populated by keyword_suggestions / related_keywords so the caller
+        # can surface it via NodeExecutionStats.provider_cost for billing.
+        self.last_cost_usd: float = 0.0
 
     def _get_headers(self) -> Dict[str, str]:
         """Generate the authorization header using Basic Auth."""
@@ -97,6 +101,9 @@ class DataForSeoClient:
         if data.get("tasks") and len(data["tasks"]) > 0:
             task = data["tasks"][0]
             if task.get("status_code") == 20000:  # Success code
+                # DataForSEO reports per-task USD cost; stash it so callers
+                # can populate NodeExecutionStats.provider_cost.
+                self.last_cost_usd = float(task.get("cost") or 0.0)
                 return task.get("result", [])
             else:
                 error_msg = task.get("status_message", "Task failed")
@@ -174,6 +181,9 @@ class DataForSeoClient:
         if data.get("tasks") and len(data["tasks"]) > 0:
             task = data["tasks"][0]
             if task.get("status_code") == 20000:  # Success code
+                # DataForSEO reports per-task USD cost; stash it so callers
+                # can populate NodeExecutionStats.provider_cost.
+                self.last_cost_usd = float(task.get("cost") or 0.0)
                 return task.get("result", [])
             else:
                 error_msg = task.get("status_message", "Task failed")
@@ -12,6 +12,11 @@ dataforseo = (
         password_env_var="DATAFORSEO_PASSWORD",
         title="DataForSEO Credentials",
     )
-    .with_base_cost(1, BlockCostType.RUN)
+    # DataForSEO reports USD cost per task (e.g. $0.001/keyword returned).
+    # DataForSeoClient stashes it on last_cost_usd; each block emits it via
+    # merge_stats so the COST_USD resolver bills against real spend.
+    # 1000 platform credits per USD → 1 credit per $0.001 (≈ 1 credit/
+    # returned keyword on the standard tier).
+    .with_base_cost(1000, BlockCostType.COST_USD)
    .build()
)
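
With 1000 credits per USD, the standard $0.001-per-returned-keyword task rate maps to one credit per keyword. A hedged example of that mapping (assuming the resolver's ceil rounding; the tier rate is the one quoted in the comment above):

```python
import math

keywords_returned = 42
task_cost_usd = keywords_returned * 0.001  # DataForSEO standard tier
credits = math.ceil(task_cost_usd * 1000)  # -> 42 credits
```
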
@@ -4,6 +4,7 @@ DataForSEO Google Keyword Suggestions block.
 
 from typing import Any, Dict, List, Optional
 
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     Block,
     BlockCategory,
@@ -110,8 +111,10 @@ class DataForSeoKeywordSuggestionsBlock(Block):
             test_output=[
                 (
                     "suggestion",
-                    lambda x: hasattr(x, "keyword")
-                    and x.keyword == "digital marketing strategy",
+                    lambda x: (
+                        hasattr(x, "keyword")
+                        and x.keyword == "digital marketing strategy"
+                    ),
                 ),
                 ("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
                 ("total_count", 1),
@@ -167,6 +170,16 @@ class DataForSeoKeywordSuggestionsBlock(Block):
 
         results = await self._fetch_keyword_suggestions(client, input_data)
 
+        # DataForSEO reports per-task USD cost on the response. Feed it
+        # into NodeExecutionStats so the COST_USD resolver bills the
+        # real provider spend at reconciliation time.
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=client.last_cost_usd,
+                provider_cost_type="cost_usd",
+            )
+        )
+
         # Process and format the results
         suggestions = []
         if results and len(results) > 0:
@@ -4,6 +4,7 @@ DataForSEO Google Related Keywords block.
 
 from typing import Any, Dict, List, Optional
 
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     Block,
     BlockCategory,
@@ -177,6 +178,16 @@ class DataForSeoRelatedKeywordsBlock(Block):
 
         results = await self._fetch_related_keywords(client, input_data)
 
+        # DataForSEO reports per-task USD cost on the response. Feed it
+        # into NodeExecutionStats so the COST_USD resolver bills the
+        # real provider spend at reconciliation time.
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=client.last_cost_usd,
+                provider_cost_type="cost_usd",
+            )
+        )
+
         # Process and format the results
         related_keywords = []
         if results and len(results) > 0:
@@ -11,6 +11,11 @@ exa = (
     ProviderBuilder("exa")
     .with_api_key("EXA_API_KEY", "Exa API Key")
     .with_webhook_manager(ExaWebhookManager)
-    .with_base_cost(1, BlockCostType.RUN)
+    # Exa returns `cost_dollars.total` on every response and ExaSearchBlock
+    # (plus ~45 sibling blocks that share this provider config) already
+    # populates NodeExecutionStats.provider_cost with it. Bill 100 credits
+    # per USD (~$0.01/credit): cheap searches stay at 1–2 credits, a Deep
+    # Research run at $0.20 lands at 20 credits, matching provider spend.
+    .with_base_cost(100, BlockCostType.COST_USD)
     .build()
 )
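
At 100 credits per USD the charge stays proportional to Exa's reported `cost_dollars.total`. For instance, assuming the resolver's ceil rounding (the dollar figures below are the examples from the comment above, not measured values):

```python
import math

cheap_search_usd = 0.005
deep_research_usd = 0.20
assert math.ceil(cheap_search_usd * 100) == 1    # stays at 1 credit
assert math.ceil(deep_research_usd * 100) == 20  # $0.20 -> 20 credits
```
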
@@ -17,6 +17,7 @@ from backend.sdk import (
 )
 
 from ._config import exa
+from .helpers import merge_exa_cost
 
 
 class AnswerCitation(BaseModel):
@@ -111,3 +112,7 @@ class ExaAnswerBlock(Block):
         yield "citations", citations
         for citation in citations:
             yield "citation", citation
+
+        # Current SDK AnswerResponse dataclass omits cost_dollars; helper
+        # no-ops today, but keeps billing wired when exa_py adds the field.
+        merge_exa_cost(self, response)
@@ -9,7 +9,6 @@ from typing import Union
 
 from pydantic import BaseModel
 
-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -23,6 +22,7 @@ from backend.sdk import (
 )
 
 from ._config import exa
+from .helpers import merge_exa_cost
 
 
 class CodeContextResponse(BaseModel):
@@ -118,9 +118,5 @@ class ExaCodeContextBlock(Block):
         yield "search_time", context.search_time
         yield "output_tokens", context.output_tokens
 
-        # Parse cost_dollars (API returns as string, e.g. "0.005")
-        try:
-            cost_usd = float(context.cost_dollars)
-            self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
-        except (ValueError, TypeError):
-            pass
+        # API returns costDollars as a bare numeric string like "0.005".
+        merge_exa_cost(self, data)
@@ -4,7 +4,6 @@ from typing import Optional
 
 from exa_py import AsyncExa
 from pydantic import BaseModel
 
-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -24,6 +23,7 @@ from .helpers import (
     HighlightSettings,
     LivecrawlTypes,
     SummarySettings,
+    merge_exa_cost,
 )
 
@@ -224,6 +224,4 @@ class ExaContentsBlock(Block):
 
         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
-            self.merge_stats(
-                NodeExecutionStats(provider_cost=response.cost_dollars.total)
-            )
+        merge_exa_cost(self, response)
@@ -143,7 +143,9 @@ class TestExaContentsCostTracking:
         mock_exa_cls.return_value = mock_exa
 
         async for _ in block.run(
-            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -172,7 +174,9 @@ class TestExaContentsCostTracking:
         mock_exa_cls.return_value = mock_exa
 
         async for _ in block.run(
-            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -201,7 +205,9 @@ class TestExaContentsCostTracking:
         mock_exa_cls.return_value = mock_exa
 
         async for _ in block.run(
-            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -297,7 +303,9 @@ class TestExaSimilarCostTracking:
         mock_exa_cls.return_value = mock_exa
 
         async for _ in block.run(
-            block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -326,7 +334,9 @@ class TestExaSimilarCostTracking:
         mock_exa_cls.return_value = mock_exa
 
         async for _ in block.run(
-            block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -1,7 +1,8 @@
 from enum import Enum
 from typing import Any, Dict, Literal, Optional, Union
 
-from backend.sdk import BaseModel, MediaFileType, SchemaField
+from backend.data.model import NodeExecutionStats
+from backend.sdk import BaseModel, Block, MediaFileType, SchemaField
 
 
 class LivecrawlTypes(str, Enum):
@@ -319,7 +320,7 @@ class CostDollars(BaseModel):
 
 # Helper functions for payload processing
 def process_text_field(
-    text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
+    text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
 ) -> Optional[Union[bool, Dict[str, Any]]]:
     """Process text field for API payload."""
     if text is None:
@@ -400,7 +401,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,
 
 
 def process_context_field(
-    context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
+    context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
 ) -> Optional[Union[bool, Dict[str, int]]]:
     """Process context field for API payload."""
     if context is None:
@@ -448,3 +449,65 @@ def add_optional_fields(
             payload[api_field] = value.value
         else:
             payload[api_field] = value
+
+
+def extract_exa_cost_usd(response: Any) -> Optional[float]:
+    """Return ``cost_dollars.total`` (USD) from an Exa SDK response, or None.
+
+    Handles dataclass/pydantic responses (``response.cost_dollars.total``),
+    dicts with camelCase keys (``response["costDollars"]["total"]``), dicts
+    with snake_case keys, and bare numeric strings. Returns None whenever the
+    shape is missing cost info — the caller then skips merge_stats.
+    """
+    if response is None:
+        return None
+
+    # Dataclass / pydantic: response.cost_dollars
+    cost_obj = getattr(response, "cost_dollars", None)
+
+    # Dict payloads: try both camelCase and snake_case
+    if cost_obj is None and isinstance(response, dict):
+        cost_obj = response.get("costDollars") or response.get("cost_dollars")
+
+    if cost_obj is None:
+        return None
+
+    # Already a scalar (code_context endpoint returns a string)
+    if isinstance(cost_obj, (int, float)):
+        return max(0.0, float(cost_obj))
+    if isinstance(cost_obj, str):
+        try:
+            return max(0.0, float(cost_obj))
+        except ValueError:
+            return None
+
+    # Nested object/dict: grab the `total` field
+    total = getattr(cost_obj, "total", None)
+    if total is None and isinstance(cost_obj, dict):
+        total = cost_obj.get("total")
+
+    if total is None:
+        return None
+
+    try:
+        return max(0.0, float(total))
+    except (TypeError, ValueError):
+        return None
+
+
+def merge_exa_cost(block: Block, response: Any) -> None:
+    """Pull ``cost_dollars.total`` off an Exa response and merge it into stats.
+
+    No-op when the response shape has no cost info (e.g. webset CRUD where
+    the SDK does not expose per-call pricing) — emission happens only when
+    Exa actually reports a USD amount.
+    """
+    cost_usd = extract_exa_cost_usd(response)
+    if cost_usd is None:
+        return
+    block.merge_stats(
+        NodeExecutionStats(
+            provider_cost=cost_usd,
+            provider_cost_type="cost_usd",
+        )
+    )
@@ -0,0 +1,65 @@
"""Unit tests for exa/helpers cost-extraction + merge helpers."""

from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest

from backend.blocks.exa.helpers import extract_exa_cost_usd, merge_exa_cost
from backend.data.model import NodeExecutionStats


@pytest.mark.parametrize(
    "response, expected",
    [
        # Dataclass / SimpleNamespace with cost_dollars.total
        (SimpleNamespace(cost_dollars=SimpleNamespace(total=0.05)), 0.05),
        # Dict camelCase
        ({"costDollars": {"total": 0.10}}, 0.10),
        # Dict snake_case
        ({"cost_dollars": {"total": 0.07}}, 0.07),
        # code_context endpoint shape: plain numeric string
        (SimpleNamespace(cost_dollars="0.005"), 0.005),
        # Scalar float on cost_dollars directly
        (SimpleNamespace(cost_dollars=0.02), 0.02),
        # Scalar int on cost_dollars
        (SimpleNamespace(cost_dollars=3), 3.0),
        # Missing cost info — returns None
        ({}, None),
        (SimpleNamespace(other="foo"), None),
        (None, None),
        # Nested total=None
        (SimpleNamespace(cost_dollars=SimpleNamespace(total=None)), None),
        # Invalid numeric string
        (SimpleNamespace(cost_dollars="not-a-number"), None),
        # Negative values clamp to 0
        (SimpleNamespace(cost_dollars=SimpleNamespace(total=-1.0)), 0.0),
    ],
)
def test_extract_exa_cost_usd_handles_all_shapes(response, expected):
    assert extract_exa_cost_usd(response) == expected


def test_merge_exa_cost_emits_stats_when_cost_present():
    block = MagicMock()
    response = SimpleNamespace(cost_dollars=SimpleNamespace(total=0.0421))
    merge_exa_cost(block, response)

    block.merge_stats.assert_called_once()
    stats: NodeExecutionStats = block.merge_stats.call_args.args[0]
    assert stats.provider_cost == pytest.approx(0.0421)
    assert stats.provider_cost_type == "cost_usd"


def test_merge_exa_cost_noops_when_no_cost():
    """Webset CRUD endpoints don't surface cost_dollars today — the helper
    must silently skip instead of emitting a 0-cost telemetry record."""
    block = MagicMock()
    merge_exa_cost(block, SimpleNamespace(other_field="nothing"))
    block.merge_stats.assert_not_called()


def test_merge_exa_cost_noops_when_response_is_none():
    block = MagicMock()
    merge_exa_cost(block, None)
    block.merge_stats.assert_not_called()
@@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel
 
-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -26,6 +25,7 @@ from backend.sdk import (
 )
 
 from ._config import exa
+from .helpers import merge_exa_cost
 
 
 class ResearchModel(str, Enum):
@@ -233,11 +233,7 @@ class ExaCreateResearchBlock(Block):
 
                 if research.cost_dollars:
                     yield "cost_total", research.cost_dollars.total
-                    self.merge_stats(
-                        NodeExecutionStats(
-                            provider_cost=research.cost_dollars.total
-                        )
-                    )
+                merge_exa_cost(self, research)
                 return
 
             await asyncio.sleep(check_interval)
@@ -352,9 +348,7 @@ class ExaGetResearchBlock(Block):
             yield "cost_searches", research.cost_dollars.num_searches
             yield "cost_pages", research.cost_dollars.num_pages
             yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
-            self.merge_stats(
-                NodeExecutionStats(provider_cost=research.cost_dollars.total)
-            )
+        merge_exa_cost(self, research)
 
         yield "error_message", research.error
 
@@ -441,9 +435,7 @@ class ExaWaitForResearchBlock(Block):
 
                 if research.cost_dollars:
                     yield "cost_total", research.cost_dollars.total
-                    self.merge_stats(
-                        NodeExecutionStats(provider_cost=research.cost_dollars.total)
-                    )
+                merge_exa_cost(self, research)
 
                 return
@@ -4,7 +4,6 @@ from typing import Optional
 
 from exa_py import AsyncExa
 
-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -21,6 +20,7 @@ from .helpers import (
     ContentSettings,
     CostDollars,
     ExaSearchResults,
+    merge_exa_cost,
     process_contents_settings,
 )
 
@@ -207,6 +207,4 @@ class ExaSearchBlock(Block):
 
         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
-            self.merge_stats(
-                NodeExecutionStats(provider_cost=response.cost_dollars.total)
-            )
+        merge_exa_cost(self, response)
@@ -3,7 +3,6 @@ from typing import Optional
 
 from exa_py import AsyncExa
 
-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -20,6 +19,7 @@ from .helpers import (
     ContentSettings,
     CostDollars,
     ExaSearchResults,
+    merge_exa_cost,
     process_contents_settings,
 )
 
@@ -168,6 +168,4 @@ class ExaFindSimilarBlock(Block):
 
         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
-            self.merge_stats(
-                NodeExecutionStats(provider_cost=response.cost_dollars.total)
-            )
+        merge_exa_cost(self, response)
@@ -39,6 +39,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost


 class SearchEntityType(str, Enum):
@@ -394,6 +395,7 @@ class ExaCreateWebsetBlock(Block):
                 metadata=input_data.metadata,
             )
         )
+        merge_exa_cost(self, webset)

         webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

@@ -404,6 +406,7 @@ class ExaCreateWebsetBlock(Block):
                 timeout=input_data.polling_timeout,
                 poll_interval=5,
             )
+            merge_exa_cost(self, final_webset)
             completion_time = time.time() - start_time

             item_count = 0
@@ -479,6 +482,7 @@ class ExaCreateOrFindWebsetBlock(Block):

         try:
             webset = await aexa.websets.get(id=input_data.external_id)
+            merge_exa_cost(self, webset)
             webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

             yield "webset", webset_result
@@ -501,6 +505,7 @@ class ExaCreateOrFindWebsetBlock(Block):
                 metadata=input_data.metadata,
             )
         )
+        merge_exa_cost(self, webset)

         webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

@@ -555,6 +560,7 @@ class ExaUpdateWebsetBlock(Block):
             payload["metadata"] = input_data.metadata

         sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
+        merge_exa_cost(self, sdk_webset)

         status_str = (
             sdk_webset.status.value
@@ -566,8 +572,9 @@ class ExaUpdateWebsetBlock(Block):
         yield "status", status_str
         yield "external_id", sdk_webset.external_id
         yield "metadata", sdk_webset.metadata or {}
-        yield "updated_at", (
-            sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
+        yield (
+            "updated_at",
+            (sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
         )

@@ -621,6 +628,7 @@ class ExaListWebsetsBlock(Block):
             cursor=input_data.cursor,
             limit=input_data.limit,
         )
+        merge_exa_cost(self, response)

         websets_data = [
             w.model_dump(by_alias=True, exclude_none=True) for w in response.data
@@ -679,6 +687,7 @@ class ExaGetWebsetBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         sdk_webset = await aexa.websets.get(id=input_data.webset_id)
+        merge_exa_cost(self, sdk_webset)

         status_str = (
             sdk_webset.status.value
@@ -706,11 +715,13 @@ class ExaGetWebsetBlock(Block):
         yield "enrichments", enrichments_data
         yield "monitors", monitors_data
         yield "metadata", sdk_webset.metadata or {}
-        yield "created_at", (
-            sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""
+        yield (
+            "created_at",
+            (sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""),
         )
-        yield "updated_at", (
-            sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
+        yield (
+            "updated_at",
+            (sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
         )

@@ -749,6 +760,7 @@ class ExaDeleteWebsetBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
+        merge_exa_cost(self, deleted_webset)

         status_str = (
             deleted_webset.status.value
@@ -799,6 +811,7 @@ class ExaCancelWebsetBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
+        merge_exa_cost(self, canceled_webset)

         status_str = (
             canceled_webset.status.value
@@ -969,6 +982,7 @@ class ExaPreviewWebsetBlock(Block):
             payload["entity"] = entity

         sdk_preview = await aexa.websets.preview(params=payload)
+        merge_exa_cost(self, sdk_preview)

         preview = PreviewWebsetModel.from_sdk(sdk_preview)

@@ -1052,6 +1066,7 @@ class ExaWebsetStatusBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         webset = await aexa.websets.get(id=input_data.webset_id)
+        merge_exa_cost(self, webset)

         status = (
             webset.status.value
@@ -1186,6 +1201,7 @@ class ExaWebsetSummaryBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         webset = await aexa.websets.get(id=input_data.webset_id)
+        merge_exa_cost(self, webset)

         # Extract basic info
         webset_id = webset.id
@@ -1214,6 +1230,7 @@ class ExaWebsetSummaryBlock(Block):
         items_response = await aexa.websets.items.list(
             webset_id=input_data.webset_id, limit=input_data.sample_size
         )
+        merge_exa_cost(self, items_response)
         sample_items_data = [
             item.model_dump(by_alias=True, exclude_none=True)
             for item in items_response.data
@@ -1363,6 +1380,7 @@ class ExaWebsetReadyCheckBlock(Block):

         # Get webset details
         webset = await aexa.websets.get(id=input_data.webset_id)
+        merge_exa_cost(self, webset)

         status = (
             webset.status.value

@@ -25,6 +25,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost


 # Mirrored model for stability
@@ -205,6 +206,7 @@ class ExaCreateEnrichmentBlock(Block):
         sdk_enrichment = await aexa.websets.enrichments.create(
             webset_id=input_data.webset_id, params=payload
         )
+        merge_exa_cost(self, sdk_enrichment)

         enrichment_id = sdk_enrichment.id
         status = (
@@ -226,6 +228,7 @@ class ExaCreateEnrichmentBlock(Block):
                 current_enrich = await aexa.websets.enrichments.get(
                     webset_id=input_data.webset_id, id=enrichment_id
                 )
+                merge_exa_cost(self, current_enrich)
                 current_status = (
                     current_enrich.status.value
                     if hasattr(current_enrich.status, "value")
@@ -235,6 +238,7 @@ class ExaCreateEnrichmentBlock(Block):
                 if current_status in ["completed", "failed", "cancelled"]:
                     # Estimate items from webset searches
                     webset = await aexa.websets.get(id=input_data.webset_id)
+                    merge_exa_cost(self, webset)
                     if webset.searches:
                         for search in webset.searches:
                             if search.progress:
@@ -332,6 +336,7 @@ class ExaGetEnrichmentBlock(Block):
         sdk_enrichment = await aexa.websets.enrichments.get(
             webset_id=input_data.webset_id, id=input_data.enrichment_id
         )
+        merge_exa_cost(self, sdk_enrichment)

         enrichment = WebsetEnrichmentModel.from_sdk(sdk_enrichment)

@@ -425,6 +430,7 @@ class ExaUpdateEnrichmentBlock(Block):
         try:
             response = await Requests().patch(url, headers=headers, json=payload)
             data = response.json()
+            # PATCH /websets/{id}/enrichments/{id} doesn't return costDollars.

             yield "enrichment_id", data.get("id", "")
             yield "status", data.get("status", "")
@@ -477,6 +483,7 @@ class ExaDeleteEnrichmentBlock(Block):
         deleted_enrichment = await aexa.websets.enrichments.delete(
             webset_id=input_data.webset_id, id=input_data.enrichment_id
         )
+        merge_exa_cost(self, deleted_enrichment)

         yield "enrichment_id", deleted_enrichment.id
         yield "success", "true"
@@ -528,12 +535,14 @@ class ExaCancelEnrichmentBlock(Block):
         canceled_enrichment = await aexa.websets.enrichments.cancel(
             webset_id=input_data.webset_id, id=input_data.enrichment_id
         )
+        merge_exa_cost(self, canceled_enrichment)

         # Try to estimate how many items were enriched before cancellation
         items_enriched = 0
         items_response = await aexa.websets.items.list(
             webset_id=input_data.webset_id, limit=100
         )
+        merge_exa_cost(self, items_response)

         for sdk_item in items_response.data:
             # Check if this enrichment is present

@@ -29,6 +29,7 @@ from backend.sdk import (

 from ._config import exa
 from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
+from .helpers import merge_exa_cost


 # Mirrored model for stability - don't use SDK types directly in block outputs
@@ -297,6 +298,7 @@ class ExaCreateImportBlock(Block):
         sdk_import = await aexa.websets.imports.create(
             params=payload, csv_data=input_data.csv_data
         )
+        merge_exa_cost(self, sdk_import)

         import_obj = ImportModel.from_sdk(sdk_import)

@@ -361,6 +363,7 @@ class ExaGetImportBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
+        merge_exa_cost(self, sdk_import)

         import_obj = ImportModel.from_sdk(sdk_import)

@@ -430,6 +433,7 @@ class ExaListImportsBlock(Block):
             cursor=input_data.cursor,
             limit=input_data.limit,
         )
+        merge_exa_cost(self, response)

         # Convert SDK imports to our stable models
         imports = [ImportModel.from_sdk(i) for i in response.data]
@@ -477,6 +481,7 @@ class ExaDeleteImportBlock(Block):
         deleted_import = await aexa.websets.imports.delete(
             import_id=input_data.import_id
         )
+        merge_exa_cost(self, deleted_import)

         yield "import_id", deleted_import.id
         yield "success", "true"
@@ -599,7 +604,7 @@ class ExaExportWebsetBlock(Block):
         try:
             all_items = []

-            # Use SDK's list_all iterator to fetch items
+            # list_all paginates internally; cost_dollars is not surfaced per-page
             item_iterator = aexa.websets.items.list_all(
                 webset_id=input_data.webset_id, limit=input_data.max_items
             )

@@ -30,6 +30,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost


 # Mirrored model for enrichment results
@@ -181,6 +182,7 @@ class ExaGetWebsetItemBlock(Block):
         sdk_item = await aexa.websets.items.get(
             webset_id=input_data.webset_id, id=input_data.item_id
         )
+        merge_exa_cost(self, sdk_item)

         item = WebsetItemModel.from_sdk(sdk_item)

@@ -293,6 +295,7 @@ class ExaListWebsetItemsBlock(Block):
             cursor=input_data.cursor,
             limit=input_data.limit,
         )
+        merge_exa_cost(self, response)

         items = [WebsetItemModel.from_sdk(item) for item in response.data]

@@ -343,6 +346,7 @@ class ExaDeleteWebsetItemBlock(Block):
         deleted_item = await aexa.websets.items.delete(
             webset_id=input_data.webset_id, id=input_data.item_id
         )
+        merge_exa_cost(self, deleted_item)

         yield "item_id", deleted_item.id
         yield "success", "true"
@@ -404,6 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         all_items: List[WebsetItemModel] = []
+        # list_all paginates internally; cost_dollars is not surfaced per-page
         item_iterator = aexa.websets.items.list_all(
             webset_id=input_data.webset_id, limit=input_data.max_items
         )
@@ -476,6 +481,7 @@ class ExaWebsetItemsSummaryBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         webset = await aexa.websets.get(id=input_data.webset_id)
+        merge_exa_cost(self, webset)

         entity_type = "unknown"
         if webset.searches:
@@ -498,6 +504,7 @@ class ExaWebsetItemsSummaryBlock(Block):
         items_response = await aexa.websets.items.list(
             webset_id=input_data.webset_id, limit=input_data.sample_size
         )
+        merge_exa_cost(self, items_response)
         # Convert to our stable models
         sample_items = [
             WebsetItemModel.from_sdk(item) for item in items_response.data
@@ -574,6 +581,7 @@ class ExaGetNewItemsBlock(Block):
             cursor=input_data.since_cursor,
             limit=input_data.max_items,
         )
+        merge_exa_cost(self, response)

         # Convert SDK items to our stable models
         new_items = [WebsetItemModel.from_sdk(item) for item in response.data]

@@ -25,6 +25,7 @@ from backend.sdk import (

 from ._config import exa
 from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
+from .helpers import merge_exa_cost


 # Mirrored model for stability - don't use SDK types directly in block outputs
@@ -321,6 +322,7 @@ class ExaCreateMonitorBlock(Block):
             payload["metadata"] = input_data.metadata

         sdk_monitor = await aexa.websets.monitors.create(params=payload)
+        merge_exa_cost(self, sdk_monitor)

         monitor = MonitorModel.from_sdk(sdk_monitor)

@@ -385,6 +387,7 @@ class ExaGetMonitorBlock(Block):
         aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

         sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
+        merge_exa_cost(self, sdk_monitor)

         monitor = MonitorModel.from_sdk(sdk_monitor)

@@ -479,6 +482,7 @@ class ExaUpdateMonitorBlock(Block):
         sdk_monitor = await aexa.websets.monitors.update(
             monitor_id=input_data.monitor_id, params=payload
         )
+        merge_exa_cost(self, sdk_monitor)

         # Convert to our stable model
         monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -525,6 +529,7 @@ class ExaDeleteMonitorBlock(Block):
         deleted_monitor = await aexa.websets.monitors.delete(
             monitor_id=input_data.monitor_id
         )
+        merge_exa_cost(self, deleted_monitor)

         yield "monitor_id", deleted_monitor.id
         yield "success", "true"
@@ -586,6 +591,7 @@ class ExaListMonitorsBlock(Block):
             limit=input_data.limit,
             webset_id=input_data.webset_id,
         )
+        merge_exa_cost(self, response)

         # Convert SDK monitors to our stable models
         monitors = [MonitorModel.from_sdk(m) for m in response.data]

@@ -25,6 +25,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost

 # Import WebsetItemModel for use in enrichment samples
 # This is safe as websets_items doesn't import from websets_polling
@@ -126,6 +127,7 @@ class ExaWaitForWebsetBlock(Block):
                 timeout=input_data.timeout,
                 poll_interval=input_data.check_interval,
             )
+            merge_exa_cost(self, final_webset)

             elapsed = time.time() - start_time

@@ -165,6 +167,7 @@ class ExaWaitForWebsetBlock(Block):
         while time.time() - start_time < input_data.timeout:
             # Get current webset status
             webset = await aexa.websets.get(id=input_data.webset_id)
+            merge_exa_cost(self, webset)
             current_status = (
                 webset.status.value
                 if hasattr(webset.status, "value")
@@ -210,6 +213,7 @@ class ExaWaitForWebsetBlock(Block):
         # Timeout reached
         elapsed = time.time() - start_time
         webset = await aexa.websets.get(id=input_data.webset_id)
+        merge_exa_cost(self, webset)
         final_status = (
             webset.status.value
             if hasattr(webset.status, "value")
@@ -348,6 +352,7 @@ class ExaWaitForSearchBlock(Block):
         search = await aexa.websets.searches.get(
             webset_id=input_data.webset_id, id=input_data.search_id
         )
+        merge_exa_cost(self, search)

         # Extract status
         status = (
@@ -404,6 +409,7 @@ class ExaWaitForSearchBlock(Block):
         search = await aexa.websets.searches.get(
             webset_id=input_data.webset_id, id=input_data.search_id
         )
+        merge_exa_cost(self, search)
         final_status = (
             search.status.value
             if hasattr(search.status, "value")
@@ -506,6 +512,7 @@ class ExaWaitForEnrichmentBlock(Block):
         enrichment = await aexa.websets.enrichments.get(
             webset_id=input_data.webset_id, id=input_data.enrichment_id
         )
+        merge_exa_cost(self, enrichment)

         # Extract status
         status = (
@@ -523,16 +530,20 @@ class ExaWaitForEnrichmentBlock(Block):
         items_enriched = 0

         if input_data.sample_results and status == "completed":
-            sample_data, items_enriched = (
-                await self._get_sample_enrichments(
-                    input_data.webset_id, input_data.enrichment_id, aexa
-                )
+            (
+                sample_data,
+                items_enriched,
+            ) = await self._get_sample_enrichments(
+                input_data.webset_id, input_data.enrichment_id, aexa
             )

         yield "enrichment_id", input_data.enrichment_id
         yield "final_status", status
         yield "items_enriched", items_enriched
-        yield "enrichment_title", enrichment.title or enrichment.description or ""
+        yield (
+            "enrichment_title",
+            enrichment.title or enrichment.description or "",
+        )
         yield "elapsed_time", elapsed
         if input_data.sample_results:
             yield "sample_data", sample_data
@@ -551,6 +562,7 @@ class ExaWaitForEnrichmentBlock(Block):
         enrichment = await aexa.websets.enrichments.get(
             webset_id=input_data.webset_id, id=input_data.enrichment_id
         )
+        merge_exa_cost(self, enrichment)
         final_status = (
             enrichment.status.value
             if hasattr(enrichment.status, "value")
@@ -576,6 +588,7 @@ class ExaWaitForEnrichmentBlock(Block):
         """Get sample enriched data and count."""
         # Get a few items to see enrichment results using SDK
         response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
+        merge_exa_cost(self, response)

         sample_data: list[SampleEnrichmentModel] = []
         enriched_count = 0

@@ -24,6 +24,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost


 # Mirrored model for stability
@@ -320,6 +321,7 @@ class ExaCreateWebsetSearchBlock(Block):
         sdk_search = await aexa.websets.searches.create(
             webset_id=input_data.webset_id, params=payload
         )
+        merge_exa_cost(self, sdk_search)

         search_id = sdk_search.id
         status = (
@@ -353,6 +355,7 @@ class ExaCreateWebsetSearchBlock(Block):
                 current_search = await aexa.websets.searches.get(
                     webset_id=input_data.webset_id, id=search_id
                 )
+                merge_exa_cost(self, current_search)
                 current_status = (
                     current_search.status.value
                     if hasattr(current_search.status, "value")
@@ -445,6 +448,7 @@ class ExaGetWebsetSearchBlock(Block):
         sdk_search = await aexa.websets.searches.get(
             webset_id=input_data.webset_id, id=input_data.search_id
         )
+        merge_exa_cost(self, sdk_search)

         search = WebsetSearchModel.from_sdk(sdk_search)

@@ -526,6 +530,7 @@ class ExaCancelWebsetSearchBlock(Block):
         canceled_search = await aexa.websets.searches.cancel(
             webset_id=input_data.webset_id, id=input_data.search_id
         )
+        merge_exa_cost(self, canceled_search)

         # Extract items found before cancellation
         items_found = 0
@@ -605,6 +610,7 @@ class ExaFindOrCreateSearchBlock(Block):

         # Get webset to check existing searches
         webset = await aexa.websets.get(id=input_data.webset_id)
+        merge_exa_cost(self, webset)

         # Look for existing search with same query
         existing_search = None
@@ -639,6 +645,7 @@ class ExaFindOrCreateSearchBlock(Block):
         sdk_search = await aexa.websets.searches.create(
             webset_id=input_data.webset_id, params=payload
         )
+        merge_exa_cost(self, sdk_search)

         search = WebsetSearchModel.from_sdk(sdk_search)

@@ -1,8 +1,14 @@
 from backend.sdk import BlockCostType, ProviderBuilder

+# Firecrawl bills in its own credits (1 credit ≈ $0.001). Each block's
+# run() estimates USD spend from the operation (pages scraped, limit,
+# credits_used on ExtractResponse) and merge_stats populates
+# NodeExecutionStats.provider_cost before billing reconciliation. 1000
+# platform credits per USD means 1 platform credit per Firecrawl credit
+# — roughly matches our existing per-call tier for single-page scrape.
 firecrawl = (
     ProviderBuilder("firecrawl")
     .with_api_key("FIRECRAWL_API_KEY", "Firecrawl API Key")
     .with_base_cost(1, BlockCostType.RUN)
+    .with_base_cost(1000, BlockCostType.COST_USD)
     .build()
 )

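The comment's conversion chain is easy to sanity-check: at 1 Firecrawl credit ≈ $0.001 and `with_base_cost(1000, BlockCostType.COST_USD)` charging 1000 platform credits per USD, one Firecrawl credit costs exactly one platform credit. A worked example (the rates are the ones stated in the comment above, not an official Firecrawl price sheet):

```python
# Worked example of the credit math in the comment above; rates are the
# diff's stated assumptions, not an authoritative Firecrawl price sheet.
USD_PER_FIRECRAWL_CREDIT = 0.001  # 1 credit ≈ $0.001
PLATFORM_CREDITS_PER_USD = 1000   # with_base_cost(1000, COST_USD)

firecrawl_credits = 7  # e.g. a seven-page crawl
usd_spend = firecrawl_credits * USD_PER_FIRECRAWL_CREDIT   # 0.007 USD
platform_credits = usd_spend * PLATFORM_CREDITS_PER_USD    # 7.0
assert platform_credits == firecrawl_credits
```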
@@ -4,6 +4,7 @@ from firecrawl import FirecrawlApp
 from firecrawl.v2.types import ScrapeOptions

 from backend.blocks.firecrawl._api import ScrapeFormat
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -86,6 +87,14 @@ class FirecrawlCrawlBlock(Block):
                 wait_for=input_data.wait_for,
             ),
         )
+        # Firecrawl bills 1 credit (~$0.001) per crawled page. crawl_result.data
+        # is the list of scraped pages actually returned.
+        pages = len(crawl_result.data) if crawl_result.data else 0
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=pages * 0.001, provider_cost_type="cost_usd"
+            )
+        )
         yield "data", crawl_result.data

         for data in crawl_result.data:

@@ -2,25 +2,22 @@ from typing import Any

 from firecrawl import FirecrawlApp

+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
     BlockCategory,
-    BlockCost,
-    BlockCostType,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
-    cost,
 )
 from backend.util.exceptions import BlockExecutionError

 from ._config import firecrawl


-@cost(BlockCost(2, BlockCostType.RUN))
 class FirecrawlExtractBlock(Block):
     class Input(BlockSchemaInput):
         credentials: CredentialsMetaInput = firecrawl.credentials_field()
@@ -74,4 +71,13 @@ class FirecrawlExtractBlock(Block):
                 block_id=self.id,
             ) from e

+        # Firecrawl surfaces actual credit spend on extract responses
+        # (credits_used). 1 Firecrawl credit ≈ $0.001.
+        credits_used = getattr(extract_result, "credits_used", None) or 0
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=credits_used * 0.001,
+                provider_cost_type="cost_usd",
+            )
+        )
         yield "data", extract_result.data

@@ -2,6 +2,7 @@ from typing import Any

 from firecrawl import FirecrawlApp

+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -50,6 +51,10 @@ class FirecrawlMapWebsiteBlock(Block):
         map_result = app.map(
             url=input_data.url,
         )
+        # Firecrawl bills 1 credit (~$0.001) per map request.
+        self.merge_stats(
+            NodeExecutionStats(provider_cost=0.001, provider_cost_type="cost_usd")
+        )

         # Convert SearchResult objects to dicts
         results_data = [

@@ -3,6 +3,7 @@ from typing import Any
 from firecrawl import FirecrawlApp

 from backend.blocks.firecrawl._api import ScrapeFormat
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -81,6 +82,11 @@ class FirecrawlScrapeBlock(Block):
             max_age=input_data.max_age,
             wait_for=input_data.wait_for,
         )
+        # Firecrawl bills 1 credit (~$0.001) per scraped page; scrape is a
+        # single-page operation.
+        self.merge_stats(
+            NodeExecutionStats(provider_cost=0.001, provider_cost_type="cost_usd")
+        )
         yield "data", scrape_result

         for f in input_data.formats:

@@ -4,6 +4,7 @@ from firecrawl import FirecrawlApp
 from firecrawl.v2.types import ScrapeOptions

 from backend.blocks.firecrawl._api import ScrapeFormat
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -68,6 +69,17 @@ class FirecrawlSearchBlock(Block):
                 wait_for=input_data.wait_for,
             ),
         )
+        # Firecrawl bills per returned web result (~1 credit each). The
+        # SearchResponse structure exposes `.web` when scrape_options was
+        # requested; fall back to `limit` as an upper bound estimate.
+        web_results = getattr(scrape_result, "web", None) or []
+        billed_units = max(len(web_results), 1)
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=billed_units * 0.001,
+                provider_cost_type="cost_usd",
+            )
+        )
         yield "data", scrape_result
         if hasattr(scrape_result, "web") and scrape_result.web:
             for site in scrape_result.web:

@@ -133,10 +133,21 @@ def GoogleDriveFileField(
     if allowed_mime_types:
         picker_config["allowed_mime_types"] = list(allowed_mime_types)

+    agent_builder_hint = (
+        "At runtime, feed this from an AgentGoogleDriveFileInputBlock with "
+        "matching allowed_views. NEVER hardcode a file ID in input_default "
+        "(including one parsed from a Drive URL the user pasted in chat) — "
+        "only the picker attaches the _credentials_id needed for auth."
+    )
+
     return SchemaField(
         default=None,
         title=title,
-        description=description,
+        description=(
+            f"{description.rstrip('.')}. {agent_builder_hint}"
+            if description
+            else agent_builder_hint
+        ),
         placeholder=placeholder or "Select from Google Drive",
         # Use google-drive-picker format so frontend renders existing component
         format="google-drive-picker",

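The conditional keeps a single sentence boundary between the caller's description and the appended hint; `rstrip('.')` prevents a doubled period. For example (hint abbreviated):

```python
# Behavior of the description merge above, with the hint abbreviated.
description = "Spreadsheet to read from."
agent_builder_hint = "At runtime, feed this from an AgentGoogleDriveFileInputBlock..."

merged = (
    f"{description.rstrip('.')}. {agent_builder_hint}"
    if description
    else agent_builder_hint
)
# merged == "Spreadsheet to read from. At runtime, feed this from an ..."
```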
autogpt_platform/backend/backend/blocks/google/sheets_test.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""Edge-case tests for Google Sheets block credential handling.
|
||||
|
||||
These pin the contract for the systemic auto-credential None-guard in
|
||||
``Block._execute()``: any block with an auto-credential field (via
|
||||
``GoogleDriveFileField`` etc.) that's called without resolved
|
||||
credentials must surface a clean, user-facing ``BlockExecutionError``
|
||||
— never a wrapped ``TypeError`` (missing required kwarg) or
|
||||
``AttributeError`` deep in the provider SDK.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.google.sheets import GoogleSheetsReadBlock
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_read_missing_credentials_yields_clean_error():
|
||||
"""Valid spreadsheet but no resolved credentials -> the systemic
|
||||
None-guard in ``Block._execute()`` yields a ``Missing credentials``
|
||||
error before ``run()`` is entered."""
|
||||
block = GoogleSheetsReadBlock()
|
||||
input_data = {
|
||||
"spreadsheet": {
|
||||
"id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
|
||||
"name": "Test Spreadsheet",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
"range": "Sheet1!A1:B2",
|
||||
}
|
||||
|
||||
with pytest.raises(BlockExecutionError, match="Missing credentials"):
|
||||
async for _ in block.execute(input_data):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_read_no_spreadsheet_still_hits_credentials_guard():
|
||||
"""When neither spreadsheet nor credentials are present, the
|
||||
credentials guard fires first (it runs before we hand off to
|
||||
``run()``). The user-facing message should still be the clean
|
||||
``Missing credentials`` one, not an opaque ``TypeError``."""
|
||||
block = GoogleSheetsReadBlock()
|
||||
input_data = {"range": "Sheet1!A1:B2"} # no spreadsheet, no credentials
|
||||
|
||||
with pytest.raises(BlockExecutionError, match="Missing credentials"):
|
||||
async for _ in block.execute(input_data):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_read_upstream_chained_value_skips_guard(mocker):
|
||||
"""A spreadsheet value chained in from an upstream input block (e.g.
|
||||
``AgentGoogleDriveFileInputBlock``) carries a resolved
|
||||
``_credentials_id`` that ``_acquire_auto_credentials`` didn't have
|
||||
visibility into at prep time. The systemic None-guard must NOT
|
||||
preempt run() in that case — otherwise every chained Drive-picker
|
||||
pattern crashes with a bogus ``Missing credentials`` error.
|
||||
|
||||
We short-circuit past the guard by patching the Google API client
|
||||
build; any error that escapes from run() is fine as long as the
|
||||
``Missing credentials`` message never surfaces."""
|
||||
# Patch out the real Google Sheets client build so we don't hit the
|
||||
# network and can detect we reached the provider SDK.
|
||||
mocker.patch(
|
||||
"backend.blocks.google.sheets.build",
|
||||
side_effect=RuntimeError("api-boundary-reached"),
|
||||
)
|
||||
|
||||
block = GoogleSheetsReadBlock()
|
||||
input_data = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "upstream-chained-cred-id",
|
||||
"id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
|
||||
"name": "Upstream-chained sheet",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
"range": "Sheet1!A1:B2",
|
||||
}
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
async for _ in block.execute(input_data):
|
||||
pass
|
||||
|
||||
# The guard should skip (chained data present) and let us reach run(),
|
||||
# which then hits the patched provider-SDK boundary. A "Missing
|
||||
# credentials" error here would mean the None-guard broke the
|
||||
# documented AgentGoogleDriveFileInputBlock chaining pattern.
|
||||
assert "Missing credentials" not in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sheets_read_upstream_chained_with_explicit_none_cred_id_skips_guard(
|
||||
mocker,
|
||||
):
|
||||
"""Sentry HIGH regression (thread PRRT_kwDOJKSTjM58sJfA): the
|
||||
documented chained-upstream pattern ships the spreadsheet dict with
|
||||
``_credentials_id=None`` — the executor fills in the resolved id
|
||||
between prep time and ``run()``. The previous ``_base.py`` guard
|
||||
used ``field_value.get("_credentials_id")`` and treated the falsy
|
||||
``None`` value as "missing", raising ``BlockExecutionError`` on
|
||||
every chained graph.
|
||||
|
||||
Pin the contract: the presence of the ``_credentials_id`` key — not
|
||||
its truthiness — is what signals "trust the skip". A dict with
|
||||
``_credentials_id: None`` must not preempt run()."""
|
||||
mocker.patch(
|
||||
"backend.blocks.google.sheets.build",
|
||||
side_effect=RuntimeError("api-boundary-reached"),
|
||||
)
|
||||
|
||||
block = GoogleSheetsReadBlock()
|
||||
input_data = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": None, # explicit None — chained-upstream shape
|
||||
"id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
|
||||
"name": "Upstream-chained sheet (None cred_id)",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
"range": "Sheet1!A1:B2",
|
||||
}
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
async for _ in block.execute(input_data):
|
||||
pass
|
||||
|
||||
# The guard must not raise "Missing credentials" for this shape.
|
||||
# We expect to reach run() and hit the patched provider-SDK boundary.
|
||||
assert "Missing credentials" not in str(exc_info.value)
|
||||
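These four tests pin a guard whose implementation is not shown in this diff. A minimal sketch of the contract they describe, assuming the guard runs inside `Block._execute()` before `run()` is entered (names are illustrative):

```python
# Illustrative sketch of the guard contract pinned by the tests above;
# the real code lives in Block._execute() in _base.py, outside this diff.
from backend.util.exceptions import BlockExecutionError


def auto_credentials_guard(field_value, resolved_credentials) -> None:
    # Key presence, not truthiness: a chained-upstream dict ships
    # {"_credentials_id": None} and the executor resolves the id later,
    # so the guard must trust it and let run() proceed.
    if isinstance(field_value, dict) and "_credentials_id" in field_value:
        return
    if resolved_credentials is None:
        raise BlockExecutionError("Missing credentials")
```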
@@ -737,7 +737,22 @@ class AgentGoogleDriveFileInputBlock(AgentInputBlock):
         )
         super().__init__(
             id="d3b32f15-6fd7-40e3-be52-e083f51b19a2",
-            description="Block for selecting a file from Google Drive.",
+            description=(
+                "Agent-level input for a Google Drive file. REQUIRED for any "
+                "agent that reads or writes a Drive file (Sheets, Docs, "
+                "Slides, or generic Drive) — the picker is the only source "
+                "of the _credentials_id needed at runtime, so consuming "
+                "blocks cannot receive a hardcoded ID. Set allowed_views to "
+                'match the consumer: ["SPREADSHEETS"] for Sheets, '
+                '["DOCUMENTS"] for Docs, ["PRESENTATIONS"] for Slides '
+                "(leave default for generic Drive). Wire `result` to the "
+                "consumer block's Drive field and leave that field unset in "
+                "the consumer's input_default. Example link to a Google "
+                'Sheets block: {"source_name": "result", "sink_name": '
+                '"spreadsheet"} (use "document" for Docs, "presentation" '
+                "for Slides). Use one input block per distinct file; "
+                "multiple consumers of the same file share it."
+            ),
             disabled=not config.enable_agent_input_subtype_blocks,
             input_schema=AgentGoogleDriveFileInputBlock.Input,
             output_schema=AgentGoogleDriveFileInputBlock.Output,

@@ -15,7 +15,7 @@ from backend.blocks.jina._auth import (
     JinaCredentialsInput,
 )
 from backend.blocks.search import GetRequest
-from backend.data.model import SchemaField
+from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util.exceptions import BlockExecutionError
 from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host

@@ -70,6 +70,13 @@ class SearchTheWebBlock(Block, GetRequest):
                 block_id=self.id,
             ) from e

+        # Jina Reader Search: $0.01/query on the paid tier. Fixed per-query
+        # cost; routed through COST_USD so the platform cost log records
+        # real USD spend (costMicrodollars) alongside the credit charge.
+        self.merge_stats(
+            NodeExecutionStats(provider_cost=0.01, provider_cost_type="cost_usd")
+        )
+
         # Output the search results
         yield "results", results

@@ -128,10 +135,16 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
         try:
             content = await self.get_request(url, json=False, headers=headers)
         except HTTPClientError as e:
-            yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}"
+            yield (
+                "error",
+                f"Client error ({e.status_code}) fetching {input_data.url}: {e}",
+            )
             return
         except HTTPServerError as e:
-            yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}"
+            yield (
+                "error",
+                f"Server error ({e.status_code}) fetching {input_data.url}: {e}",
+            )
             return
         except Exception as e:
             yield "error", f"Failed to fetch {input_data.url}: {e}"

@@ -206,6 +206,10 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
     GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
     KIMI_K2 = "moonshotai/kimi-k2"
+    KIMI_K2_0905 = "moonshotai/kimi-k2-0905"
+    KIMI_K2_5 = "moonshotai/kimi-k2.5"
+    KIMI_K2_6 = "moonshotai/kimi-k2.6"
+    KIMI_K2_THINKING = "moonshotai/kimi-k2-thinking"
     QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
     QWEN3_CODER = "qwen/qwen3-coder"
     # Z.ai (Zhipu) models
@@ -646,6 +650,24 @@ MODEL_METADATA = {
     LlmModel.KIMI_K2: ModelMetadata(
         "open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
     ),
+    LlmModel.KIMI_K2_0905: ModelMetadata(
+        "open_router", 262144, 262144, "Kimi K2 0905", "OpenRouter", "Moonshot AI", 1
+    ),
+    LlmModel.KIMI_K2_5: ModelMetadata(
+        "open_router", 262144, 262144, "Kimi K2.5", "OpenRouter", "Moonshot AI", 1
+    ),
+    LlmModel.KIMI_K2_6: ModelMetadata(
+        "open_router", 262144, 262144, "Kimi K2.6", "OpenRouter", "Moonshot AI", 2
+    ),
+    LlmModel.KIMI_K2_THINKING: ModelMetadata(
+        "open_router",
+        262144,
+        262144,
+        "Kimi K2 Thinking",
+        "OpenRouter",
+        "Moonshot AI",
+        2,
+    ),
     LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
         "open_router",
         262144,

@@ -1602,6 +1624,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                             llm_call_count=retry_count + 1,
                             llm_retry_count=retry_count,
                             provider_cost=total_provider_cost,
+                            provider_cost_type=(
+                                "cost_usd"
+                                if total_provider_cost is not None
+                                else None
+                            ),
                         )
                     )
                     yield "response", response_obj
@@ -1623,6 +1650,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                             llm_call_count=retry_count + 1,
                             llm_retry_count=retry_count,
                             provider_cost=total_provider_cost,
+                            provider_cost_type=(
+                                "cost_usd" if total_provider_cost is not None else None
+                            ),
                         )
                     )
                     yield "response", {"response": response_text}
@@ -1657,7 +1687,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
         # All retries exhausted or user-error break: persist accumulated cost so
         # the executor can still charge/report the spend even on failure.
         if total_provider_cost is not None:
-            self.merge_stats(NodeExecutionStats(provider_cost=total_provider_cost))
+            self.merge_stats(
+                NodeExecutionStats(
+                    provider_cost=total_provider_cost,
+                    provider_cost_type="cost_usd",
+                )
+            )
         raise RuntimeError(error_feedback_message)

     def response_format_instructions(

@@ -376,20 +376,12 @@ class OrchestratorBlock(Block):
     re-raise carve-out for this reason.
     """

-    def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
-        """Charge one extra runtime cost per LLM call beyond the first.
-
-        In agent mode each iteration makes one LLM call. The first is already
-        covered by charge_usage(); this returns the number of additional
-        credits so the executor can bill the remaining calls post-completion.
-
-        SDK-mode exemption: when the block runs via _execute_tools_sdk_mode,
-        the SDK manages its own conversation loop and only exposes aggregate
-        usage. We hardcode llm_call_count=1 there (the SDK does not report a
-        per-turn call count), so this method always returns 0 for SDK-mode
-        executions. Per-iteration billing does not apply to SDK mode.
-        """
-        return max(0, execution_stats.llm_call_count - 1)
+    # OrchestratorBlock bills via BlockCostType.TOKENS + compute_token_credits,
+    # which aggregates input_token_count / output_token_count / cache_read /
+    # cache_creation across every LLM iteration into one post-flight charge.
+    # The per-iteration flat-fee path (Block.extra_runtime_cost →
+    # charge_extra_runtime_cost) would double-bill the same tokens, so
+    # OrchestratorBlock deliberately inherits the base-class no-op default.

     # MCP server name used by the Claude Code SDK execution mode. Keep in sync
     # with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
@@ -1189,10 +1181,14 @@ class OrchestratorBlock(Block):
                 not execution_params.execution_context.dry_run
                 and tool_node_stats.error is None
             ):
+                # Charge the sub-block for telemetry / wallet debit. The
+                # return value is intentionally discarded: on_node_execution
+                # above ran the sub-block against this graph's own
+                # graph_stats_pair (manager.py:659-668), so its cost already
+                # lands in graph_stats.cost on the sub-block's completion.
+                # Re-merging here would double-count in telemetry / UI / audit.
                 try:
-                    tool_cost, _ = await execution_processor.charge_node_usage(
-                        node_exec_entry,
-                    )
+                    await execution_processor.charge_node_usage(node_exec_entry)
                 except InsufficientBalanceError:
                     # IBE must propagate — see OrchestratorBlock class docstring.
                     # Log the billing failure here so the discarded tool result
@@ -1214,9 +1210,6 @@ class OrchestratorBlock(Block):
                         "tool execution was successful",
                         sink_node_id,
                     )
-                    tool_cost = 0
-            if tool_cost > 0:
-                self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))

             # Get outputs from database after execution completes using database manager client
             node_outputs = await db_client.get_execution_outputs_by_node_exec_id(

@@ -13,6 +13,7 @@ from backend.blocks._base import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.llm import extract_openrouter_cost
 from backend.data.block import BlockInput
 from backend.data.model import (
     APIKeyCredentials,
@@ -98,14 +99,23 @@ class PerplexityBlock(Block):
             return _sanitize_perplexity_model(v)

         @classmethod
-        def validate_data(cls, data: BlockInput) -> str | None:
+        def validate_data(
+            cls,
+            data: BlockInput,
+            exclude_fields: set[str] | None = None,
+        ) -> str | None:
             """Sanitize the model field before JSON schema validation so that
             invalid values are replaced with the default instead of raising a
-            BlockInputError."""
+            BlockInputError.
+
+            Signature matches ``BlockSchema.validate_data`` (including the
+            optional ``exclude_fields`` kwarg added for dry-run credential
+            bypass) so Pyright doesn't flag this as an incompatible override.
+            """
             model_value = data.get("model")
             if model_value is not None:
                 data["model"] = _sanitize_perplexity_model(model_value).value
-            return super().validate_data(data)
+            return super().validate_data(data, exclude_fields=exclude_fields)

         system_prompt: str = SchemaField(
             title="System Prompt",
@@ -230,12 +240,17 @@ class PerplexityBlock(Block):
                 if "message" in choice and "annotations" in choice["message"]:
                     annotations = choice["message"]["annotations"]

-            # Update execution stats
+            # Update execution stats. ``execution_stats`` is instance state,
+            # so always reset token counters — a response without ``usage``
+            # must not leak a previous run's tokens into ``PlatformCostLog``.
+            self.execution_stats.input_token_count = 0
+            self.execution_stats.output_token_count = 0
             if response.usage:
                 self.execution_stats.input_token_count = response.usage.prompt_tokens
                 self.execution_stats.output_token_count = (
                     response.usage.completion_tokens
                 )
+            self._record_openrouter_cost(response)

             return {"response": response_content, "annotations": annotations or []}

@@ -243,6 +258,17 @@ class PerplexityBlock(Block):
             logger.error(f"Error calling Perplexity: {e}")
             raise

+    def _record_openrouter_cost(self, response: Any) -> None:
+        """Feed OpenRouter's ``x-total-cost`` USD into execution stats for
+        the COST_USD resolver. Tag as ``cost_usd`` only when the value is
+        concrete and positive — leaving it unset on None/0 keeps the
+        billing gap observable instead of silently floored to 0.
+        """
+        cost_usd = extract_openrouter_cost(response)
+        self.execution_stats.provider_cost = cost_usd
+        if cost_usd is not None and cost_usd > 0:
+            self.execution_stats.provider_cost_type = "cost_usd"
+
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:

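`extract_openrouter_cost` is imported from `backend.blocks.llm` and its body is not in this diff. A plausible sketch, assuming the OpenRouter response surfaces its USD charge on a usage-style field (the real helper may read a different location, such as a response header):

```python
# Plausible sketch only: the real extract_openrouter_cost lives in
# backend.blocks.llm and may read the charge from a different location.
from typing import Any


def extract_openrouter_cost(response: Any) -> float | None:
    usage = getattr(response, "usage", None)
    cost = getattr(usage, "cost", None) if usage is not None else None
    return float(cost) if cost is not None else None
```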
@@ -14,6 +14,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -160,10 +161,13 @@ class PineconeQueryBlock(Block):
             combined_text = "\n\n".join(texts)

             # Return both the raw matches and combined text
-            yield "results", {
-                "matches": results["matches"],
-                "combined_text": combined_text,
-            }
+            yield (
+                "results",
+                {
+                    "matches": results["matches"],
+                    "combined_text": combined_text,
+                },
+            )
             yield "combined_results", combined_text

         except Exception as e:
@@ -228,6 +232,13 @@ class PineconeInsertBlock(Block):
             )
             idx.upsert(vectors=vectors, namespace=input_data.namespace)

+            self.merge_stats(
+                NodeExecutionStats(
+                    provider_cost=float(len(vectors)),
+                    provider_cost_type="items",
+                )
+            )
+
             yield "upsert_response", "successfully upserted"

         except Exception as e:

@@ -1,8 +1,12 @@
 from backend.sdk import BlockCostType, ProviderBuilder

+# 1 credit per 3 walltime seconds. Block walltime proxies for the
+# Browserbase session lifetime + the LLM call it issues. Interim until
+# the block emits real provider_cost (USD) via merge_stats and migrates
+# to COST_USD.
 stagehand = (
     ProviderBuilder("stagehand")
     .with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
     .with_base_cost(1, BlockCostType.RUN)
+    .with_base_cost(1, BlockCostType.SECOND, cost_divisor=3)
     .build()
 )

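A worked example of the interim walltime pricing. The `cost_divisor=3` semantics, and whether partial three-second windows round up, are assumptions read off the builder call above:

```python
# Worked example of the interim SECOND pricing above. Whether partial
# three-second windows round up, down, or pro-rate is an assumption here.
import math

walltime_seconds = 45  # hypothetical Browserbase session length
run_credit = 1                                      # with_base_cost(1, RUN)
walltime_credits = math.ceil(walltime_seconds / 3)  # 15, at 1 credit per 3 s
total_credits = run_credit + walltime_credits       # 16
```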
File diff suppressed because it is too large
@@ -27,7 +27,7 @@ from backend.blocks.video._utils import (
     strip_chapters_inplace,
 )
 from backend.data.execution import ExecutionContext
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
 from backend.util.exceptions import BlockExecutionError
 from backend.util.file import MediaFileType, get_exec_file_path, store_media_file

@@ -44,7 +44,8 @@ class VideoNarrationBlock(Block):
     )
     script: str = SchemaField(description="Narration script text")
     voice_id: str = SchemaField(
-        description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM"  # Rachel
+        description="ElevenLabs voice ID",
+        default="21m00Tcm4TlvDq8ikWAM",  # Rachel
     )
     model_id: Literal[
         "eleven_multilingual_v2",
@@ -124,6 +125,26 @@ class VideoNarrationBlock(Block):
             return_format="for_block_output",
         )

+    # Models that consume 0.5 credits per character (v2.5 tier). All other
+    # models default to 1.0 credit per character.
+    _HALF_RATE_MODELS = {"eleven_flash_v2_5", "eleven_turbo_v2_5"}
+    # ElevenLabs Starter plan: $5 / 30K credits = $0.000167 / credit.
+    _USD_PER_CREDIT = 0.000167
+
+    def _record_script_cost(self, script: str, model_id: str) -> None:
+        """Emit provider_cost (USD) for the narration run so the COST_USD
+        resolver can bill real ElevenLabs spend. Flash/Turbo v2.5 bill at
+        half the char rate of Multilingual/Turbo v2.
+        """
+        credits_per_char = 0.5 if model_id in self._HALF_RATE_MODELS else 1.0
+        script_usd = len(script) * self._USD_PER_CREDIT * credits_per_char
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=script_usd,
+                provider_cost_type="cost_usd",
+            )
+        )
+
     def _generate_narration_audio(
         self, api_key: str, script: str, voice_id: str, model_id: str
     ) -> bytes:
@@ -223,6 +244,8 @@ class VideoNarrationBlock(Block):
             input_data.model_id,
         )

+        self._record_script_cost(input_data.script, input_data.model_id)
+
         # Save audio to exec file path
         audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
         audio_abspath = get_exec_file_path(

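Plugging the two class constants into `_record_script_cost` for a 1,000-character script shows the half-rate tier in action:

```python
# Worked example using the constants added above.
usd_per_credit = 0.000167  # $5 / 30K credits (Starter plan)
script_chars = 1_000

multilingual_v2_usd = script_chars * usd_per_credit * 1.0  # ≈ $0.167
flash_v2_5_usd = script_chars * usd_per_credit * 0.5       # ≈ $0.0835
```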
@@ -21,7 +21,7 @@ from backend.blocks.zerobounce._auth import (
     ZeroBounceCredentials,
     ZeroBounceCredentialsInput,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


 class Response(BaseModel):
@@ -140,20 +140,22 @@ class ValidateEmailsBlock(Block):
                 )
             ],
             test_mock={
-                "validate_email": lambda email, ip_address, credentials: ZBValidateResponse(
-                    data={
-                        "address": email,
-                        "status": ZBValidateStatus.valid,
-                        "sub_status": ZBValidateSubStatus.allowed,
-                        "account": "test",
-                        "domain": "test.com",
-                        "did_you_mean": None,
-                        "domain_age_days": None,
-                        "free_email": False,
-                        "mx_found": False,
-                        "mx_record": None,
-                        "smtp_provider": None,
-                    }
+                "validate_email": lambda email, ip_address, credentials: (
+                    ZBValidateResponse(
+                        data={
+                            "address": email,
+                            "status": ZBValidateStatus.valid,
+                            "sub_status": ZBValidateSubStatus.allowed,
+                            "account": "test",
+                            "domain": "test.com",
+                            "did_you_mean": None,
+                            "domain_age_days": None,
+                            "free_email": False,
+                            "mx_found": False,
+                            "mx_record": None,
+                            "smtp_provider": None,
+                        }
+                    )
+                )
             },
         )
@@ -176,6 +178,13 @@ class ValidateEmailsBlock(Block):
             input_data.email, input_data.ip_address, credentials
         )

+        # ZeroBounce bills $0.008 per validated email on the paid tier.
+        # Routed through COST_USD so platform cost telemetry captures real
+        # USD spend; the resolver still bills 2 credits per call.
+        self.merge_stats(
+            NodeExecutionStats(provider_cost=0.008, provider_cost_type="cost_usd")
+        )
+
         response_model = Response(**response.__dict__)

         yield "response", response_model

autogpt_platform/backend/backend/copilot/baseline/reasoning.py (new file, 364 lines)
@@ -0,0 +1,364 @@
"""Extended-thinking wire support for the baseline (OpenRouter) path.
|
||||
|
||||
OpenRouter routes that support extended thinking (Anthropic Claude and
|
||||
Moonshot Kimi today) expose reasoning through non-OpenAI extension fields
|
||||
that the OpenAI Python SDK doesn't model:
|
||||
|
||||
* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``.
|
||||
* ``reasoning_content`` — DeepSeek / some OpenRouter routes.
|
||||
* ``reasoning_details`` — structured list shipped with the unified
|
||||
``reasoning`` request param.
|
||||
|
||||
This module keeps the wire-level concerns in one place:
|
||||
|
||||
* :class:`OpenRouterDeltaExtension` validates the extension dict pulled off
|
||||
``ChoiceDelta.model_extra`` into typed pydantic models — no ``getattr`` +
|
||||
``isinstance`` duck-typing at the call site.
|
||||
* :class:`BaselineReasoningEmitter` owns the reasoning block lifecycle for
|
||||
one streaming round and emits ``StreamReasoning*`` events so the caller
|
||||
only has to plumb the events into its pending queue.
|
||||
* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the
|
||||
OpenAI client call. Returns ``None`` for routes without reasoning
|
||||
support (see :func:`_is_reasoning_route`).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamReasoningDelta,
|
||||
StreamReasoningEnd,
|
||||
StreamReasoningStart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
|
||||
|
||||
# Coalescing thresholds for ``StreamReasoningDelta`` emission. OpenRouter's
|
||||
# Kimi K2.6 endpoint tokenises reasoning at a much finer grain than Anthropic
|
||||
# (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without
|
||||
# coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React
|
||||
# re-render of the non-virtualised chat list, which paint-storms the browser
|
||||
# main thread and freezes the UI. Batching into ~64-char / ~50 ms windows
|
||||
# cuts the event rate ~150x while staying snappy enough that the Reasoning
|
||||
# collapse still feels live (well under the ~100 ms perceptual threshold).
|
||||
# Per-delta persistence to ``session.messages`` stays granular — we only
|
||||
# coalesce the *wire* emission.
|
||||
_COALESCE_MIN_CHARS = 64
|
||||
_COALESCE_MAX_INTERVAL_MS = 50.0
|
||||
|
||||
|
||||
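A minimal sketch of the flush rule those two thresholds imply. The emitter class itself is only partially shown in this listing, so flush-on-either-threshold is an assumption about its internals:

```python
# Minimal sketch of the coalescing rule implied by the thresholds above;
# the real BaselineReasoningEmitter is only partially shown here, so
# flush-on-either-threshold is an assumption.
import time


class _CoalesceBuffer:
    def __init__(self) -> None:
        self._parts: list[str] = []
        self._last_flush = time.monotonic()

    def add(self, delta: str) -> str | None:
        """Buffer one reasoning delta; return coalesced text when a flush is due."""
        self._parts.append(delta)
        buffered = "".join(self._parts)
        elapsed_ms = (time.monotonic() - self._last_flush) * 1000.0
        if len(buffered) >= _COALESCE_MIN_CHARS or elapsed_ms >= _COALESCE_MAX_INTERVAL_MS:
            self._parts.clear()
            self._last_flush = time.monotonic()
            return buffered
        return None
```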
class ReasoningDetail(BaseModel):
    """One entry in OpenRouter's ``reasoning_details`` list.

    OpenRouter ships ``type: "reasoning.text"`` / ``"reasoning.summary"`` /
    ``"reasoning.encrypted"`` entries. Only the first two carry
    user-visible text; encrypted entries are opaque and omitted from the
    rendered collapse. Unknown future types are tolerated (``extra="ignore"``)
    so an upstream addition doesn't crash the stream — but their ``text`` /
    ``summary`` fields are NOT surfaced because they may carry provider
    metadata rather than user-visible reasoning (see
    :attr:`visible_text`).
    """

    model_config = ConfigDict(extra="ignore")

    type: str | None = None
    text: str | None = None
    summary: str | None = None

    @property
    def visible_text(self) -> str:
        """Return the human-readable text for this entry, or ``""``.

        Only entries with a recognised reasoning type (``reasoning.text`` /
        ``reasoning.summary``) surface text; unknown or encrypted types
        return an empty string even if they carry a ``text`` /
        ``summary`` field, to guard against future provider metadata
        being rendered as reasoning in the UI. Entries missing a
        ``type`` are treated as text (pre-``reasoning_details`` OpenRouter
        payloads omit the field).
        """
        if self.type is not None and self.type not in _VISIBLE_REASONING_TYPES:
            return ""
        return self.text or self.summary or ""


class OpenRouterDeltaExtension(BaseModel):
    """Non-OpenAI fields OpenRouter adds to streaming deltas.

    Instantiate via :meth:`from_delta` which pulls the extension dict off
    ``ChoiceDelta.model_extra`` (where pydantic v2 stashes fields that
    aren't part of the declared schema) and validates it through this
    model. That keeps the parser honest — malformed entries surface as
    validation errors rather than silent ``None``-coalesce bugs — and
    avoids the ``getattr`` + ``isinstance`` duck-typing the earlier inline
    extractor relied on.
    """

    model_config = ConfigDict(extra="ignore")

    reasoning: str | None = None
    reasoning_content: str | None = None
    reasoning_details: list[ReasoningDetail] = Field(default_factory=list)

    @classmethod
    def from_delta(cls, delta: ChoiceDelta) -> "OpenRouterDeltaExtension":
        """Build an extension view from ``delta.model_extra``.

        Malformed provider payloads (e.g. ``reasoning_details`` shipped as
        a string rather than a list) surface as a ``ValidationError`` which
        is logged and swallowed — returning an empty extension so the rest
        of the stream (valid text / tool calls) keeps flowing. An optional
        feature's corrupted wire data must never abort the whole stream.
        """
        try:
            return cls.model_validate(delta.model_extra or {})
        except ValidationError as exc:
            logger.warning(
                "[Baseline] Dropping malformed OpenRouter reasoning payload: %s",
                exc,
            )
            return cls()

    def visible_text(self) -> str:
        """Concatenated reasoning text, pulled from whichever channel is set.

        Priority: the legacy ``reasoning`` string, then DeepSeek's
        ``reasoning_content``, then the concatenation of text-bearing
        entries in ``reasoning_details``. Only one channel is set per
        provider in practice; the priority order just makes the fallback
        deterministic if a provider ever emits multiple.
        """
        if self.reasoning:
            return self.reasoning
        if self.reasoning_content:
            return self.reasoning_content
        return "".join(d.visible_text for d in self.reasoning_details)
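A usage sketch of the validation path, with hand-written payloads that mirror the channels the docstrings above describe:

```python
# Usage sketch; the payload dicts are hand-written to mirror the channels
# described in the docstrings above.
ext = OpenRouterDeltaExtension.model_validate(
    {"reasoning_details": [{"type": "reasoning.text", "text": "step 1..."}]}
)
assert ext.visible_text() == "step 1..."

# Encrypted entries are opaque and surface no user-visible text.
opaque = OpenRouterDeltaExtension.model_validate(
    {"reasoning_details": [{"type": "reasoning.encrypted", "text": "x"}]}
)
assert opaque.visible_text() == ""
```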


def _is_reasoning_route(model: str) -> bool:
    """Return True when the route supports OpenRouter's ``reasoning`` extension.

    OpenRouter exposes reasoning tokens via a unified ``reasoning`` request
    param that works on any provider that supports extended thinking —
    currently Anthropic (Claude Opus / Sonnet) and Moonshot (Kimi K2.6 +
    kimi-k2-thinking) advertise it in their ``supported_parameters``.
    Other providers silently drop the field, but we skip it anyway to keep
    the payload tight and avoid confusing cache diagnostics.

    Kept separate from :func:`backend.copilot.baseline.service._is_anthropic_model`
    because ``cache_control`` is strictly Anthropic-specific (Moonshot does
    its own auto-caching), so the two gates must not conflate.

    Both the Claude and Kimi matches are anchored to the provider
    prefix (or to a bare model id with no prefix at all) to avoid
    substring false positives — a custom ``some-other-provider/claude-mock``
    or ``provider/hakimi-large`` configured via
    ``CHAT_FAST_STANDARD_MODEL`` must NOT inherit the reasoning
    extra_body and take a 400 from its upstream. Recognised shapes:

    * Claude — ``anthropic/`` or ``anthropic.`` provider prefix, or a
      bare ``claude-`` model id with no provider prefix
      (``claude-opus-4.7``, ``anthropic/claude-sonnet-4-6``,
      ``anthropic.claude-3-5-sonnet``). A non-Anthropic prefix like
      ``someprovider/claude-mock`` is rejected on purpose.
    * Kimi — ``moonshotai/`` provider prefix, or a ``kimi-`` model id
      with no provider prefix (``kimi-k2.6``,
      ``moonshotai/kimi-k2-thinking``). Like Claude, a non-Moonshot
      prefix is rejected — exception: ``openrouter/kimi-k2.6`` stays
      recognised because ``openrouter/`` is how we route to Moonshot
      today and changing that would be a behaviour regression for
      existing deployments.
    """
    lowered = model.lower()
    if lowered.startswith(("anthropic/", "anthropic.")):
        return True
    if lowered.startswith("moonshotai/"):
        return True
    # ``openrouter/`` historically routes to whatever the default
    # upstream for the model is — for kimi that's Moonshot, so accept
    # ``openrouter/kimi-...`` here. Any other ``openrouter/`` model
    # (e.g. ``openrouter/auto``) still carries a provider prefix and is
    # rejected by the generic prefix check below.
    if lowered.startswith("openrouter/kimi-"):
        return True
    if "/" in lowered:
        # Any other provider prefix is a custom / non-Anthropic /
        # non-Moonshot route and must not opt into reasoning. This
        # blocks substring false positives like
        # ``some-provider/claude-mock-v1`` or ``other/kimi-pro``.
        return False
    # No provider prefix — accept bare ``claude-*`` and ``kimi-*`` ids
    # so direct CLI configs (``claude-3-5-sonnet-20241022``,
    # ``kimi-k2-instruct``) keep working.
    return lowered.startswith("claude-") or lowered.startswith("kimi-")
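
# Examples (mirroring the route tests shipped with this change):
#     _is_reasoning_route("anthropic/claude-sonnet-4-6")   -> True
#     _is_reasoning_route("claude-3-5-sonnet-20241022")    -> True
#     _is_reasoning_route("moonshotai/kimi-k2-thinking")   -> True
#     _is_reasoning_route("openrouter/kimi-k2.6")          -> True
#     _is_reasoning_route("someprovider/claude-mock-v1")   -> False
#     _is_reasoning_route("other/kimi-pro")                -> False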


def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None:
    """Build the ``extra_body["reasoning"]`` fragment for the OpenAI client.

    Returns ``None`` for non-reasoning routes and for
    ``max_thinking_tokens <= 0`` (operator kill switch).
    """
    if not _is_reasoning_route(model) or max_thinking_tokens <= 0:
        return None
    return {"reasoning": {"max_tokens": max_thinking_tokens}}
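
# Usage sketch (illustrative only; assumes an ``openai.AsyncOpenAI``
# client pointed at OpenRouter. ``extra_body`` is the stock
# openai-python passthrough for non-standard request fields):
#
#     extra = reasoning_extra_body(model, max_thinking_tokens)
#     stream = await client.chat.completions.create(
#         model=model,
#         messages=messages,
#         stream=True,
#         extra_body=extra or {},
#     )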


class BaselineReasoningEmitter:
    """Owns the reasoning block lifecycle for one streaming round.

    Two concerns live here, both driven by the same state machine:

    1. **Wire events.** The AI SDK v6 wire format pairs every
       ``reasoning-start`` with a matching ``reasoning-end`` and treats
       reasoning / text / tool-use as distinct UI parts that must not
       interleave.
    2. **Session persistence.** ``ChatMessage(role="reasoning")`` rows in
       ``session.messages`` are what
       ``convertChatSessionToUiMessages.ts`` folds into the assistant
       bubble as ``{type: "reasoning"}`` UI parts on reload and on
       ``useHydrateOnStreamEnd`` swaps. Without them the live-streamed
       reasoning parts get overwritten by the hydrated (reasoning-less)
       message list the moment the stream ends. Mirrors the SDK path's
       ``acc.reasoning_response`` pattern so both routes render the same
       way on reload.

    Pass ``session_messages`` to enable persistence; omit for pure
    wire-emission (tests, scratch callers). On first reasoning delta a
    fresh ``ChatMessage(role="reasoning")`` is appended and mutated
    in-place as further deltas arrive; :meth:`close` drops the reference
    but leaves the appended row intact.

    ``render_in_ui=False`` suppresses only the live wire events
    (``StreamReasoning*``); the ``role='reasoning'`` persistence row is
    still appended so ``convertChatSessionToUiMessages.ts`` can hydrate
    the reasoning bubble on reload. The state machine advances
    identically either way.
    """

    def __init__(
        self,
        session_messages: list[ChatMessage] | None = None,
        *,
        coalesce_min_chars: int = _COALESCE_MIN_CHARS,
        coalesce_max_interval_ms: float = _COALESCE_MAX_INTERVAL_MS,
        render_in_ui: bool = True,
    ) -> None:
        self._block_id: str = str(uuid.uuid4())
        self._open: bool = False
        self._session_messages = session_messages
        self._current_row: ChatMessage | None = None
        # Coalescing state — tests can disable (``=0``) for deterministic
        # event assertions.
        self._coalesce_min_chars = coalesce_min_chars
        self._coalesce_max_interval_ms = coalesce_max_interval_ms
        self._pending_delta: str = ""
        self._last_flush_monotonic: float = 0.0
        self._render_in_ui = render_in_ui

    @property
    def is_open(self) -> bool:
        return self._open

    def on_delta(self, delta: ChoiceDelta) -> list[StreamBaseResponse]:
        """Return events for the reasoning text carried by *delta*.

        Empty list when the chunk carries no reasoning payload, so this is
        safe to call on every chunk without guarding at the call site.

        Persistence (when a session message list is attached) stays
        per-delta so the DB row's content always equals the concatenation
        of wire deltas at every chunk boundary, independent of the
        coalescing window. Only the wire emission is batched.
        """
        ext = OpenRouterDeltaExtension.from_delta(delta)
        text = ext.visible_text()
        if not text:
            return []
        events: list[StreamBaseResponse] = []
        # First reasoning text in this block — emit Start + the first Delta
        # atomically so the frontend Reasoning collapse renders immediately
        # rather than waiting for the coalesce window to elapse. Subsequent
        # chunks buffer into ``_pending_delta`` and only flush when the
        # char/time thresholds trip.
        # Sample the monotonic clock exactly once per chunk — at ~4,700
        # chunks per turn, folding the two calls into one cuts ~4,700
        # syscalls off the hot path without changing semantics.
        now = time.monotonic()
        if not self._open:
            if self._render_in_ui:
                events.append(StreamReasoningStart(id=self._block_id))
                events.append(StreamReasoningDelta(id=self._block_id, delta=text))
            self._open = True
            self._last_flush_monotonic = now
            if self._session_messages is not None:
                self._current_row = ChatMessage(role="reasoning", content=text)
                self._session_messages.append(self._current_row)
            return events

        if self._current_row is not None:
            self._current_row.content = (self._current_row.content or "") + text

        self._pending_delta += text
        if self._should_flush_pending(now):
            if self._render_in_ui:
                events.append(
                    StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
                )
            self._pending_delta = ""
            self._last_flush_monotonic = now
        return events

    def _should_flush_pending(self, now: float) -> bool:
        """Return True when the accumulated delta should be emitted now.

        *now* is the monotonic timestamp sampled by the caller so the
        clock is read at most once per chunk (the flush-timestamp update
        reuses the same value).
        """
        if not self._pending_delta:
            return False
        if len(self._pending_delta) >= self._coalesce_min_chars:
            return True
        elapsed_ms = (now - self._last_flush_monotonic) * 1000.0
        return elapsed_ms >= self._coalesce_max_interval_ms

    def close(self) -> list[StreamBaseResponse]:
        """Emit ``StreamReasoningEnd`` for the open block (if any) and rotate.

        Idempotent — returns ``[]`` when no block is open. Drains any
        still-buffered delta first so the frontend never loses tail text
        from the coalesce window. The id rotation guarantees the next
        reasoning block starts with a fresh id rather than reusing one
        already closed on the wire. The persisted row is not removed —
        it stays in ``session_messages`` as the durable record of what
        was reasoned.
        """
        if not self._open:
            return []
        events: list[StreamBaseResponse] = []
        if self._render_in_ui:
            if self._pending_delta:
                events.append(
                    StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
                )
            events.append(StreamReasoningEnd(id=self._block_id))
        self._pending_delta = ""
        self._open = False
        self._block_id = str(uuid.uuid4())
        self._current_row = None
        return events
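
A minimal driving loop makes the emitter lifecycle concrete. This is an illustrative sketch, not code from this diff: `stream` is the `openai` SDK's async chunk iterator, and `publish` / `session` are assumed stand-in names.

```python
emitter = BaselineReasoningEmitter(session_messages=session.messages)

async for chunk in stream:  # AsyncStream[ChatCompletionChunk]
    delta = chunk.choices[0].delta
    # Safe on every chunk: returns [] when no reasoning text is carried.
    for event in emitter.on_delta(delta):
        await publish(event)  # e.g. Redis xadd feeding the SSE wire
    # ...handle delta.content / delta.tool_calls as usual...

# Close before text/tool parts start so reasoning never interleaves.
for event in emitter.close():
    await publish(event)
```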
@@ -0,0 +1,514 @@
"""Tests for the baseline reasoning extension module.

Covers the typed OpenRouter delta parser, the stateful emitter, and the
``extra_body`` builder. The emitter is tested against real
``ChoiceDelta`` pydantic instances so the ``model_extra`` plumbing the
parser relies on is exercised end-to-end.
"""

from openai.types.chat.chat_completion_chunk import ChoiceDelta

from backend.copilot.baseline.reasoning import (
    BaselineReasoningEmitter,
    OpenRouterDeltaExtension,
    ReasoningDetail,
    _is_reasoning_route,
    reasoning_extra_body,
)
from backend.copilot.model import ChatMessage
from backend.copilot.response_model import (
    StreamReasoningDelta,
    StreamReasoningEnd,
    StreamReasoningStart,
)


def _delta(**extra) -> ChoiceDelta:
    """Build a ChoiceDelta with the given extension fields on ``model_extra``."""
    return ChoiceDelta.model_validate({"role": "assistant", **extra})


class TestReasoningDetail:
    def test_visible_text_prefers_text(self):
        d = ReasoningDetail(type="reasoning.text", text="hi", summary="ignored")
        assert d.visible_text == "hi"

    def test_visible_text_falls_back_to_summary(self):
        d = ReasoningDetail(type="reasoning.summary", summary="tldr")
        assert d.visible_text == "tldr"

    def test_visible_text_empty_for_encrypted(self):
        d = ReasoningDetail(type="reasoning.encrypted")
        assert d.visible_text == ""

    def test_unknown_fields_are_ignored(self):
        # OpenRouter may add new fields in future payloads — they shouldn't
        # cause validation errors.
        d = ReasoningDetail.model_validate(
            {"type": "reasoning.future", "text": "x", "signature": "opaque"}
        )
        assert d.text == "x"

    def test_visible_text_empty_for_unknown_type(self):
        # Unknown types may carry provider metadata that must not render as
        # user-visible reasoning — regardless of whether a text/summary is
        # present. Only ``reasoning.text`` / ``reasoning.summary`` surface.
        d = ReasoningDetail(type="reasoning.future", text="leaked metadata")
        assert d.visible_text == ""

    def test_visible_text_surfaces_text_when_type_missing(self):
        # Pre-``reasoning_details`` OpenRouter payloads omit ``type`` — treat
        # them as text so we don't regress the legacy structured shape.
        d = ReasoningDetail(text="plain")
        assert d.visible_text == "plain"


class TestOpenRouterDeltaExtension:
    def test_from_delta_reads_model_extra(self):
        delta = _delta(reasoning="step one")
        ext = OpenRouterDeltaExtension.from_delta(delta)
        assert ext.reasoning == "step one"

    def test_visible_text_legacy_string(self):
        ext = OpenRouterDeltaExtension(reasoning="plain text")
        assert ext.visible_text() == "plain text"

    def test_visible_text_deepseek_alias(self):
        ext = OpenRouterDeltaExtension(reasoning_content="alt channel")
        assert ext.visible_text() == "alt channel"

    def test_visible_text_structured_details_concat(self):
        ext = OpenRouterDeltaExtension(
            reasoning_details=[
                ReasoningDetail(type="reasoning.text", text="hello "),
                ReasoningDetail(type="reasoning.text", text="world"),
            ]
        )
        assert ext.visible_text() == "hello world"

    def test_visible_text_skips_encrypted(self):
        ext = OpenRouterDeltaExtension(
            reasoning_details=[
                ReasoningDetail(type="reasoning.encrypted"),
                ReasoningDetail(type="reasoning.text", text="visible"),
            ]
        )
        assert ext.visible_text() == "visible"

    def test_visible_text_empty_when_all_channels_blank(self):
        ext = OpenRouterDeltaExtension()
        assert ext.visible_text() == ""

    def test_empty_delta_produces_empty_extension(self):
        ext = OpenRouterDeltaExtension.from_delta(_delta())
        assert ext.reasoning is None
        assert ext.reasoning_content is None
        assert ext.reasoning_details == []

    def test_malformed_reasoning_payload_logged_and_swallowed(self, caplog):
        # A malformed payload (e.g. reasoning_details shipped as a string
        # rather than a list) must not abort the stream — log it and
        # return an empty extension so valid text/tool events keep flowing.
        # A plain mock is used here because ``from_delta`` only reads
        # ``delta.model_extra`` — avoids reaching into pydantic internals
        # (``__pydantic_extra__``) that could be renamed across versions.
        from unittest.mock import MagicMock

        delta = MagicMock(spec=ChoiceDelta)
        delta.model_extra = {"reasoning_details": "not a list"}
        with caplog.at_level("WARNING"):
            ext = OpenRouterDeltaExtension.from_delta(delta)
        assert ext.reasoning_details == []
        assert ext.visible_text() == ""
        assert any("malformed" in r.message.lower() for r in caplog.records)

    def test_unknown_typed_entry_with_text_is_not_surfaced(self):
        # Regression: the legacy extractor emitted any entry with a
        # ``text`` or ``summary`` field. The typed parser now filters on
        # the recognised types so future provider metadata can't leak
        # into the reasoning collapse.
        ext = OpenRouterDeltaExtension(
            reasoning_details=[
                ReasoningDetail(type="reasoning.future", text="provider metadata"),
                ReasoningDetail(type="reasoning.text", text="real"),
            ]
        )
        assert ext.visible_text() == "real"


class TestIsReasoningRoute:
    def test_anthropic_routes(self):
        assert _is_reasoning_route("anthropic/claude-sonnet-4-6")
        assert _is_reasoning_route("claude-3-5-sonnet-20241022")
        assert _is_reasoning_route("anthropic.claude-3-5-sonnet")
        assert _is_reasoning_route("ANTHROPIC/Claude-Opus")  # case-insensitive

    def test_moonshot_kimi_routes(self):
        # OpenRouter advertises the ``reasoning`` extension on Moonshot
        # endpoints — both K2.6 (the new baseline default) and the
        # reasoning-native kimi-k2-thinking variant.
        assert _is_reasoning_route("moonshotai/kimi-k2.6")
        assert _is_reasoning_route("moonshotai/kimi-k2-thinking")
        assert _is_reasoning_route("moonshotai/kimi-k2.5")
        # Direct (non-OpenRouter) model ids also resolve via the ``kimi-``
        # prefix so a future bare ``kimi-k3`` id would still match.
        assert _is_reasoning_route("kimi-k2-instruct")
        # The ``openrouter/`` prefix is special-cased for kimi ids —
        # ``openrouter/`` is how we route to Moonshot today, so
        # ``openrouter/kimi-*`` stays recognised.
        assert _is_reasoning_route("openrouter/kimi-k2.6")

    def test_other_providers_rejected(self):
        assert not _is_reasoning_route("openai/gpt-4o")
        assert not _is_reasoning_route("google/gemini-2.5-pro")
        assert not _is_reasoning_route("xai/grok-4")
        assert not _is_reasoning_route("meta-llama/llama-3.3-70b-instruct")
        assert not _is_reasoning_route("deepseek/deepseek-r1")

    def test_kimi_substring_false_positives_rejected(self):
        # Regression: the previous implementation matched any model whose
        # name contained the substring ``kimi`` — including unrelated model
        # ids like ``hakimi``. The anchored match below rejects them.
        assert not _is_reasoning_route("some-provider/hakimi-large")
        assert not _is_reasoning_route("hakimi")
        assert not _is_reasoning_route("akimi-7b")

    def test_claude_substring_false_positives_rejected(self):
        # Regression (Sentry review on #12871): ``'claude' in lowered``
        # matched any substring — a custom
        # ``someprovider/claude-mock-v1`` set via
        # ``CHAT_FAST_STANDARD_MODEL`` would inherit the reasoning
        # extra_body and take a 400 from its upstream. The anchored
        # match requires either an ``anthropic/`` / ``anthropic.``
        # prefix, or a bare ``claude-`` id with no provider prefix.
        assert not _is_reasoning_route("someprovider/claude-mock-v1")
        assert not _is_reasoning_route("custom/claude-like-model")
        # Same principle for Kimi — a non-Moonshot provider prefix is
        # rejected even when the model id starts with ``kimi-``.
        assert not _is_reasoning_route("other/kimi-pro")


class TestReasoningExtraBody:
    def test_anthropic_route_returns_fragment(self):
        assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == {
            "reasoning": {"max_tokens": 4096}
        }

    def test_direct_claude_model_id_still_matches(self):
        assert reasoning_extra_body("claude-3-5-sonnet-20241022", 2048) == {
            "reasoning": {"max_tokens": 2048}
        }

    def test_kimi_routes_return_fragment(self):
        # Kimi K2.6 ships the same OpenRouter ``reasoning`` extension as
        # Anthropic, so the gate widened with this PR and the fragment
        # must now materialise on Moonshot routes too.
        assert reasoning_extra_body("moonshotai/kimi-k2.6", 8192) == {
            "reasoning": {"max_tokens": 8192}
        }
        assert reasoning_extra_body("moonshotai/kimi-k2-thinking", 4096) == {
            "reasoning": {"max_tokens": 4096}
        }

    def test_non_reasoning_route_returns_none(self):
        assert reasoning_extra_body("openai/gpt-4o", 4096) is None
        assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None
        assert reasoning_extra_body("xai/grok-4", 4096) is None

    def test_zero_max_tokens_kill_switch(self):
        # Operator kill switch: ``max_thinking_tokens <= 0`` disables the
        # ``reasoning`` extra_body fragment on ANY reasoning route (Anthropic
        # or Kimi). Lets us silence reasoning without dropping the SDK
        # path's budget.
        assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None
        assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None
        assert reasoning_extra_body("moonshotai/kimi-k2.6", 0) is None


class TestBaselineReasoningEmitter:
    def test_first_text_delta_emits_start_then_delta(self):
        emitter = BaselineReasoningEmitter()
        events = emitter.on_delta(_delta(reasoning="thinking"))

        assert len(events) == 2
        assert isinstance(events[0], StreamReasoningStart)
        assert isinstance(events[1], StreamReasoningDelta)
        assert events[0].id == events[1].id
        assert events[1].delta == "thinking"
        assert emitter.is_open is True

    def test_subsequent_deltas_reuse_block_id_without_new_start(self):
        # Disable coalescing so each chunk flushes immediately — this test
        # is about the Start/Delta/block-id state machine, not the coalesce
        # window. Coalescing behaviour is covered below.
        emitter = BaselineReasoningEmitter(
            coalesce_min_chars=0, coalesce_max_interval_ms=0
        )
        first = emitter.on_delta(_delta(reasoning="a"))
        second = emitter.on_delta(_delta(reasoning="b"))

        assert any(isinstance(e, StreamReasoningStart) for e in first)
        assert all(not isinstance(e, StreamReasoningStart) for e in second)
        assert len(second) == 1
        assert isinstance(second[0], StreamReasoningDelta)
        assert first[0].id == second[0].id

    def test_empty_delta_emits_nothing(self):
        emitter = BaselineReasoningEmitter()
        assert emitter.on_delta(_delta(content="hello")) == []
        assert emitter.is_open is False

    def test_close_emits_end_and_rotates_id(self):
        emitter = BaselineReasoningEmitter()
        # Capture the block id from the wire event rather than reaching
        # into emitter internals — the id on the emitted Start/Delta is
        # what the frontend actually receives.
        start_events = emitter.on_delta(_delta(reasoning="x"))
        first_id = start_events[0].id

        events = emitter.close()
        assert len(events) == 1
        assert isinstance(events[0], StreamReasoningEnd)
        assert events[0].id == first_id
        assert emitter.is_open is False
        # Next reasoning uses a fresh id.
        new_events = emitter.on_delta(_delta(reasoning="y"))
        assert isinstance(new_events[0], StreamReasoningStart)
        assert new_events[0].id != first_id

    def test_close_is_idempotent(self):
        emitter = BaselineReasoningEmitter()
        assert emitter.close() == []
        emitter.on_delta(_delta(reasoning="x"))
        assert len(emitter.close()) == 1
        assert emitter.close() == []

    def test_structured_details_round_trip(self):
        emitter = BaselineReasoningEmitter()
        events = emitter.on_delta(
            _delta(
                reasoning_details=[
                    {"type": "reasoning.text", "text": "plan: "},
                    {"type": "reasoning.summary", "summary": "do the thing"},
                ]
            )
        )
        deltas = [e for e in events if isinstance(e, StreamReasoningDelta)]
        assert len(deltas) == 1
        assert deltas[0].delta == "plan: do the thing"


class TestReasoningDeltaCoalescing:
    """Coalescing batches fine-grained provider chunks into bigger wire
    frames. OpenRouter's Kimi K2.6 emits ~4,700 reasoning-delta chunks
    per turn vs ~28 for Sonnet; without batching, every chunk becomes one
    Redis ``xadd`` + one SSE event + one React re-render of the
    non-virtualised chat list, which paint-storms the browser. These
    tests pin the batching contract: small chunks buffer until the
    char-size or time threshold trips, large chunks still flush
    immediately, and ``close()`` never drops tail text."""

    def test_small_chunks_after_first_buffer_until_threshold(self):
        # Generous time threshold so size alone controls flush timing.
        emitter = BaselineReasoningEmitter(
            coalesce_min_chars=32, coalesce_max_interval_ms=60_000
        )
        # First chunk always flushes immediately (so UI renders without
        # waiting).
        first = emitter.on_delta(_delta(reasoning="hi "))
        assert any(isinstance(e, StreamReasoningStart) for e in first)
        assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1

        # Subsequent small chunks buffer silently — 5 × 4 chars = 20 chars,
        # still under the 32-char threshold.
        for _ in range(5):
            assert emitter.on_delta(_delta(reasoning="abcd")) == []

        # Once the threshold is crossed, the accumulated buffer flushes
        # as a single StreamReasoningDelta carrying every buffered chunk.
        flush = emitter.on_delta(_delta(reasoning="efghijklmnop"))
        assert len(flush) == 1
        assert isinstance(flush[0], StreamReasoningDelta)
        assert flush[0].delta == "abcd" * 5 + "efghijklmnop"

    def test_time_based_flush_when_chars_stay_below_threshold(self, monkeypatch):
        # Fake ``time.monotonic`` so we can drive the time-based branch
        # deterministically without real sleeps.
        from backend.copilot.baseline import reasoning as rmod

        fake_now = [0.0]
        monkeypatch.setattr(rmod.time, "monotonic", lambda: fake_now[0])

        emitter = BaselineReasoningEmitter(
            coalesce_min_chars=1000, coalesce_max_interval_ms=40
        )
        # t=0: first chunk flushes immediately.
        first = emitter.on_delta(_delta(reasoning="a"))
        assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1

        # t=10 ms: still under 40 ms → buffer.
        fake_now[0] = 0.010
        assert emitter.on_delta(_delta(reasoning="b")) == []

        # t=60 ms, i.e. 60 ms since the last flush → time threshold trips,
        # flush fires.
        fake_now[0] = 0.060
        flushed = emitter.on_delta(_delta(reasoning="c"))
        assert len(flushed) == 1
        assert isinstance(flushed[0], StreamReasoningDelta)
        assert flushed[0].delta == "bc"

    def test_close_flushes_tail_buffer_before_end(self):
        emitter = BaselineReasoningEmitter(
            coalesce_min_chars=1000, coalesce_max_interval_ms=60_000
        )
        emitter.on_delta(_delta(reasoning="first"))  # flushes (first chunk)
        emitter.on_delta(_delta(reasoning=" middle "))  # buffered
        emitter.on_delta(_delta(reasoning="tail"))  # buffered

        events = emitter.close()
        assert len(events) == 2
        assert isinstance(events[0], StreamReasoningDelta)
        assert events[0].delta == " middle tail"
        assert isinstance(events[1], StreamReasoningEnd)

    def test_coalesce_disabled_flushes_every_chunk(self):
        emitter = BaselineReasoningEmitter(
            coalesce_min_chars=0, coalesce_max_interval_ms=0
        )
        first = emitter.on_delta(_delta(reasoning="a"))
        second = emitter.on_delta(_delta(reasoning="b"))
        assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1
        assert sum(isinstance(e, StreamReasoningDelta) for e in second) == 1

    def test_persistence_stays_per_delta_even_when_wire_coalesces(self):
        """DB row content must track every chunk so a crash mid-turn
        persists the full reasoning-so-far, even if the coalesce window
        never flushed those chunks to the wire."""
        session: list[ChatMessage] = []
        emitter = BaselineReasoningEmitter(
            session,
            coalesce_min_chars=1000,
            coalesce_max_interval_ms=60_000,
        )
        emitter.on_delta(_delta(reasoning="first "))
        emitter.on_delta(_delta(reasoning="chunk "))
        emitter.on_delta(_delta(reasoning="three"))
        # No close; verify the persisted row already has everything.
        assert len(session) == 1
        assert session[0].content == "first chunk three"


class TestReasoningPersistence:
    """The persistence contract: without ``role="reasoning"`` rows in
    session.messages, useHydrateOnStreamEnd overwrites the live-streamed
    reasoning parts and the Reasoning collapse vanishes. Every delta
    must be reflected in the persisted row the moment it's emitted."""

    def test_session_row_appended_on_first_delta(self):
        session: list[ChatMessage] = []
        emitter = BaselineReasoningEmitter(session)

        assert session == []
        emitter.on_delta(_delta(reasoning="hi"))
        assert len(session) == 1
        assert session[0].role == "reasoning"
        assert session[0].content == "hi"

    def test_subsequent_deltas_mutate_same_row(self):
        session: list[ChatMessage] = []
        emitter = BaselineReasoningEmitter(session)

        emitter.on_delta(_delta(reasoning="part one "))
        emitter.on_delta(_delta(reasoning="part two"))

        assert len(session) == 1
        assert session[0].content == "part one part two"

    def test_close_keeps_row_in_session(self):
        session: list[ChatMessage] = []
        emitter = BaselineReasoningEmitter(session)

        emitter.on_delta(_delta(reasoning="thought"))
        emitter.close()

        assert len(session) == 1
        assert session[0].content == "thought"

    def test_second_reasoning_block_appends_new_row(self):
        session: list[ChatMessage] = []
        emitter = BaselineReasoningEmitter(session)

        emitter.on_delta(_delta(reasoning="first"))
        emitter.close()
        emitter.on_delta(_delta(reasoning="second"))

        assert len(session) == 2
        assert [m.content for m in session] == ["first", "second"]

    def test_no_session_means_no_persistence(self):
        """Emitter without an attached session list emits wire events only."""
        emitter = BaselineReasoningEmitter()
        events = emitter.on_delta(_delta(reasoning="pure wire"))
        assert len(events) == 2  # start + delta, no crash
        # Nothing else to assert — just proves a None session is supported.


class TestBaselineReasoningEmitterRenderFlag:
    """``render_in_ui=False`` must silence the ``StreamReasoning*`` wire
    events while still persisting the ``role="reasoning"`` rows — the
    operator hides the live collapse, but the session transcript keeps
    the thinking text for hydration, audit, and replay. These tests pin
    the contract in both directions so future refactors can't flip only
    one half."""

    def test_render_off_suppresses_start_and_delta(self):
        emitter = BaselineReasoningEmitter(render_in_ui=False)
        events = emitter.on_delta(_delta(reasoning="hidden"))
        # No wire events, but state advanced (is_open == True) so close()
        # below has something to rotate.
        assert events == []
        assert emitter.is_open is True

    def test_render_off_suppresses_close_end(self):
        emitter = BaselineReasoningEmitter(render_in_ui=False)
        emitter.on_delta(_delta(reasoning="hidden"))
        events = emitter.close()
        assert events == []
        assert emitter.is_open is False

    def test_render_off_still_persists(self):
        """Persistence is decoupled from the render flag — the session
        transcript always keeps the ``role="reasoning"`` row so audit
        and ``--resume``-equivalent replay never lose thinking text.
        The frontend gates rendering separately."""
        session: list[ChatMessage] = []
        emitter = BaselineReasoningEmitter(session, render_in_ui=False)

        emitter.on_delta(_delta(reasoning="part one "))
        emitter.on_delta(_delta(reasoning="part two"))
        emitter.close()

        assert len(session) == 1
        assert session[0].role == "reasoning"
        assert session[0].content == "part one part two"

    def test_render_off_rotates_block_id_between_sessions(self):
        """Even with wire events silenced the block id must rotate on close,
        otherwise a hypothetical mid-session flip would reuse a stale id."""
        emitter = BaselineReasoningEmitter(render_in_ui=False)
        emitter.on_delta(_delta(reasoning="first"))
        first_block_id = emitter._block_id
        emitter.close()
        emitter.on_delta(_delta(reasoning="second"))
        assert emitter._block_id != first_block_id

    def test_render_on_is_default(self):
        """Defaulting to True preserves backward compat — existing callers
        that don't pass the kwarg keep emitting wire events as before."""
        emitter = BaselineReasoningEmitter()
        events = emitter.on_delta(_delta(reasoning="hello"))
        assert len(events) == 2
        assert isinstance(events[0], StreamReasoningStart)
        assert isinstance(events[1], StreamReasoningDelta)
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -63,21 +63,117 @@ def _make_session_messages(*roles: str) -> list[ChatMessage]:
 
 
 class TestResolveBaselineModel:
-    """Baseline model resolution honours the per-request tier toggle."""
-
-    def test_advanced_tier_selects_advanced_model(self):
-        assert _resolve_baseline_model("advanced") == config.advanced_model
-
-    def test_standard_tier_selects_default_model(self):
-        assert _resolve_baseline_model("standard") == config.model
-
-    def test_none_tier_selects_default_model(self):
-        """Baseline users without a tier MUST keep the default (standard)."""
-        assert _resolve_baseline_model(None) == config.model
-
-    def test_standard_and_advanced_models_differ(self):
-        """Advanced tier defaults to a different (Opus) model than standard."""
-        assert config.model != config.advanced_model
+    """Baseline model resolution honours the per-request tier toggle.
+
+    Baseline reads the ``fast_*_model`` cells of the (path, tier) matrix
+    and never falls through to the SDK-side ``thinking_*_model`` cells.
+    Without a user_id (so no LD context) the resolver returns the
+    ``ChatConfig`` static default; per-user overrides are exercised in
+    ``copilot/model_router_test.py``.
+    """
+
+    @pytest.mark.asyncio
+    async def test_advanced_tier_selects_fast_advanced_model(self):
+        assert (
+            await _resolve_baseline_model("advanced", None)
+            == config.fast_advanced_model
+        )
+
+    @pytest.mark.asyncio
+    async def test_standard_tier_selects_fast_standard_model(self):
+        assert (
+            await _resolve_baseline_model("standard", None)
+            == config.fast_standard_model
+        )
+
+    @pytest.mark.asyncio
+    async def test_none_tier_selects_fast_standard_model(self):
+        """Baseline users without a tier get the fast-standard default."""
+        assert await _resolve_baseline_model(None, None) == config.fast_standard_model
+
+    def test_fast_standard_default_is_sonnet(self):
+        """Shipped default: Sonnet on the baseline standard cell — the
+        non-Anthropic routes ship via the LD flag instead of a config
+        change. Asserts the declared ``Field`` default so a deploy-time
+        ``CHAT_FAST_STANDARD_MODEL`` override doesn't flake CI."""
+        from backend.copilot.config import ChatConfig
+
+        assert (
+            ChatConfig.model_fields["fast_standard_model"].default
+            == "anthropic/claude-sonnet-4-6"
+        )
+
+    def test_fast_advanced_default_is_opus(self):
+        """Shipped default: Opus on the baseline advanced cell — mirrors
+        the SDK advanced cell so the advanced-tier A/B stays clean
+        (same model, different path)."""
+        from backend.copilot.config import ChatConfig
+
+        assert (
+            ChatConfig.model_fields["fast_advanced_model"].default
+            == "anthropic/claude-opus-4.7"
+        )
+
+    def test_standard_and_advanced_cells_differ_on_fast(self):
+        """Advanced tier defaults to a different model than standard on
+        the baseline path. Checked against declared ``Field`` defaults
+        so operator env overrides don't flake the test."""
+        from backend.copilot.config import ChatConfig
+
+        assert (
+            ChatConfig.model_fields["fast_standard_model"].default
+            != ChatConfig.model_fields["fast_advanced_model"].default
+        )
+
+    def test_legacy_env_aliases_route_to_new_fields(self, monkeypatch):
+        """Backward compat: the pre-split env var names must still bind.
+
+        The four-field matrix was introduced with ``validation_alias``
+        entries so that existing deployments setting ``CHAT_MODEL`` /
+        ``CHAT_ADVANCED_MODEL`` / ``CHAT_FAST_MODEL`` continue to override
+        the same effective cell without a rename. Construct a fresh
+        ``ChatConfig`` with each legacy name set and confirm it lands on
+        the new field.
+        """
+        from backend.copilot.config import ChatConfig
+
+        monkeypatch.setenv("CHAT_MODEL", "legacy/sonnet-via-chat-model")
+        monkeypatch.setenv("CHAT_ADVANCED_MODEL", "legacy/opus-via-advanced")
+        monkeypatch.setenv("CHAT_FAST_MODEL", "legacy/fast-via-fast-model")
+
+        cfg = ChatConfig()
+
+        assert cfg.thinking_standard_model == "legacy/sonnet-via-chat-model"
+        assert cfg.thinking_advanced_model == "legacy/opus-via-advanced"
+        assert cfg.fast_standard_model == "legacy/fast-via-fast-model"
+
+    def test_all_four_new_env_vars_bind_to_their_cells(self, monkeypatch):
+        """Each of the four (path, tier) cells must be overridable via
+        its documented ``CHAT_*_*_MODEL`` env var — including
+        ``CHAT_FAST_ADVANCED_MODEL`` which was missing a
+        ``validation_alias`` in the original split and only bound
+        implicitly through ``env_prefix``. Pinning all four here so
+        that whenever someone touches the config shape, an accidental
+        unbinding fails CI instead of silently ignoring operator
+        overrides.
+        """
+        from backend.copilot.config import ChatConfig
+
+        monkeypatch.setenv("CHAT_FAST_STANDARD_MODEL", "explicit/fast-std")
+        monkeypatch.setenv("CHAT_FAST_ADVANCED_MODEL", "explicit/fast-adv")
+        monkeypatch.setenv("CHAT_THINKING_STANDARD_MODEL", "explicit/think-std")
+        monkeypatch.setenv("CHAT_THINKING_ADVANCED_MODEL", "explicit/think-adv")
+        # Clear the legacy aliases so they don't win priority in
+        # ``AliasChoices`` (first match wins).
+        for legacy in ("CHAT_MODEL", "CHAT_ADVANCED_MODEL", "CHAT_FAST_MODEL"):
+            monkeypatch.delenv(legacy, raising=False)
+
+        cfg = ChatConfig()
+
+        assert cfg.fast_standard_model == "explicit/fast-std"
+        assert cfg.fast_advanced_model == "explicit/fast-adv"
+        assert cfg.thinking_standard_model == "explicit/think-std"
+        assert cfg.thinking_advanced_model == "explicit/think-adv"
 
 
 class TestLoadPriorTranscript:
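
The alias mechanics the last two tests pin down are standard pydantic-settings behaviour. A minimal sketch of the pattern, assuming a `BaseSettings`-backed `ChatConfig` (the real definition is not shown in this diff):

```python
from pydantic import AliasChoices, Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class ChatConfigSketch(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="CHAT_")

    # Legacy name listed first: AliasChoices resolves in order, so an
    # existing CHAT_MODEL deployment keeps overriding this cell.
    thinking_standard_model: str = Field(
        default="anthropic/claude-sonnet-4-6",
        validation_alias=AliasChoices(
            "CHAT_MODEL", "CHAT_THINKING_STANDARD_MODEL"
        ),
    )
    # No explicit alias: binds only through env_prefix + field name,
    # i.e. CHAT_FAST_ADVANCED_MODEL (the gap the last test guards).
    fast_advanced_model: str = "anthropic/claude-opus-4.7"
```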
217
autogpt_platform/backend/backend/copilot/builder_context.py
Normal file
@@ -0,0 +1,217 @@
"""Builder-session context helpers — split cacheable system prompt from
the volatile per-turn snapshot so Claude's prompt cache stays warm."""

from __future__ import annotations

import logging
from typing import Any

from backend.copilot.model import ChatSession
from backend.copilot.permissions import CopilotPermissions
from backend.copilot.tools.agent_generator import get_agent_as_json
from backend.copilot.tools.get_agent_building_guide import _load_guide

logger = logging.getLogger(__name__)


BUILDER_CONTEXT_TAG = "builder_context"
BUILDER_SESSION_TAG = "builder_session"


# Tools hidden from builder-bound sessions: ``create_agent`` /
# ``customize_agent`` would mint a new graph (panel is bound to one),
# and ``get_agent_building_guide`` duplicates bytes already in the
# system-prompt suffix. Everything else (find_block, find_agent, …)
# stays available so the LLM can look up ids instead of hallucinating.
BUILDER_BLOCKED_TOOLS: tuple[str, ...] = (
    "create_agent",
    "customize_agent",
    "get_agent_building_guide",
)


def resolve_session_permissions(
    session: ChatSession | None,
) -> CopilotPermissions | None:
    """Blacklist :data:`BUILDER_BLOCKED_TOOLS` for builder-bound sessions,
    return ``None`` (unrestricted) otherwise."""
    if session is None or not session.metadata.builder_graph_id:
        return None
    return CopilotPermissions(
        tools=list(BUILDER_BLOCKED_TOOLS),
        tools_exclude=True,
    )


# Caps — mirror the frontend ``serializeGraphForChat`` defaults so the
# server-side block stays within a practical token budget for large graphs.
_MAX_NODES = 100
_MAX_LINKS = 200

_FETCH_FAILED_PREFIX = (
    f"<{BUILDER_CONTEXT_TAG}>\n"
    f"<status>fetch_failed</status>\n"
    f"</{BUILDER_CONTEXT_TAG}>\n\n"
)

# Embedded in the cacheable suffix so the LLM picks the right run_agent
# dispatch mode without forcing the user to watch a long-blocking call.
_BUILDER_RUN_AGENT_GUIDANCE = (
    "You are operating inside the builder panel, not the standalone "
    "copilot page. The builder page already subscribes to agent "
    "executions the moment you return an execution_id, so for REAL "
    "(non-dry) runs prefer `run_agent(dry_run=False, wait_for_result=0)` "
    "— the user will see the run stream in the builder's execution panel "
    "in-place and your turn ends immediately with the id. For DRY-RUNS "
    "keep `dry_run=True, wait_for_result=120`: blocking is required so "
    "you can inspect `execution.node_executions` and report the verdict "
    "in the same turn."
)


def _sanitize_for_xml(value: Any) -> str:
    """Escape XML special chars — mirrors ``sanitizeForXml`` in
    ``BuilderChatPanel/helpers.ts``."""
    s = "" if value is None else str(value)
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )
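
# Illustrative behaviour (entity choices assumed from the XML standard;
# verify against sanitizeForXml in BuilderChatPanel/helpers.ts):
#     _sanitize_for_xml('1 < 2 & "ok"')  ->  '1 &lt; 2 &amp; &quot;ok&quot;'
#     _sanitize_for_xml(None)            ->  ''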


def _node_display_name(node: dict[str, Any]) -> str:
    """Prefer the user-set label (``input_default.name`` / ``metadata.title``);
    fall back to the block id."""
    defaults = node.get("input_default") or {}
    metadata = node.get("metadata") or {}
    for key in ("name", "title", "label"):
        value = defaults.get(key) or metadata.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()
    block_id = node.get("block_id") or ""
    return block_id or "unknown"


def _format_nodes(nodes: list[dict[str, Any]]) -> str:
    if not nodes:
        return "<nodes>\n</nodes>"
    visible = nodes[:_MAX_NODES]
    lines = []
    for node in visible:
        node_id = _sanitize_for_xml(node.get("id") or "")
        name = _sanitize_for_xml(_node_display_name(node))
        block_id = _sanitize_for_xml(node.get("block_id") or "")
        lines.append(f"- {node_id}: {name} ({block_id})")
    extra = len(nodes) - len(visible)
    if extra > 0:
        lines.append(f"({extra} more not shown)")
    body = "\n".join(lines)
    return f"<nodes>\n{body}\n</nodes>"


def _format_links(
    links: list[dict[str, Any]],
    nodes: list[dict[str, Any]],
) -> str:
    if not links:
        return "<links>\n</links>"
    name_by_id = {n.get("id"): _node_display_name(n) for n in nodes}
    visible = links[:_MAX_LINKS]
    lines = []
    for link in visible:
        src_id = link.get("source_id") or ""
        dst_id = link.get("sink_id") or ""
        src_name = name_by_id.get(src_id, src_id)
        dst_name = name_by_id.get(dst_id, dst_id)
        src_out = link.get("source_name") or ""
        dst_in = link.get("sink_name") or ""
        lines.append(
            f"- {_sanitize_for_xml(src_name)}.{_sanitize_for_xml(src_out)} "
            f"-> {_sanitize_for_xml(dst_name)}.{_sanitize_for_xml(dst_in)}"
        )
    extra = len(links) - len(visible)
    if extra > 0:
        lines.append(f"({extra} more not shown)")
    body = "\n".join(lines)
    return f"<links>\n{body}\n</links>"


async def build_builder_system_prompt_suffix(session: ChatSession) -> str:
    """Return the cacheable system-prompt suffix for a builder session.

    Holds only static content (dispatch guidance + building guide) so the
    bytes are identical across turns AND across sessions for different
    graphs — the live id/name/version ride on the per-turn prefix.
    """
    if not session.metadata.builder_graph_id:
        return ""

    try:
        guide = _load_guide()
    except Exception:
        logger.exception("[builder_context] Failed to load agent-building guide")
        return ""

    # The guide is trusted server-side content (read from disk). We do NOT
    # escape it — the LLM needs the raw markdown to make sense of block ids,
    # code fences, and example JSON.
    return (
        f"\n\n<{BUILDER_SESSION_TAG}>\n"
        f"<run_agent_dispatch_mode>\n"
        f"{_BUILDER_RUN_AGENT_GUIDANCE}\n"
        f"</run_agent_dispatch_mode>\n"
        f"<building_guide>\n{guide}\n</building_guide>\n"
        f"</{BUILDER_SESSION_TAG}>"
    )


async def build_builder_context_turn_prefix(
    session: ChatSession,
    user_id: str | None,
) -> str:
    """Return the per-turn ``<builder_context>`` prefix with the live
    graph snapshot (id/name/version/nodes/links). ``""`` for non-builder
    sessions; fetch-failure marker if the graph cannot be read."""
    graph_id = session.metadata.builder_graph_id
    if not graph_id:
        return ""

    try:
        agent_json = await get_agent_as_json(graph_id, user_id)
    except Exception:
        logger.exception(
            "[builder_context] Failed to fetch graph %s for session %s",
            graph_id,
            session.session_id,
        )
        return _FETCH_FAILED_PREFIX

    if not agent_json:
        logger.warning(
            "[builder_context] Graph %s not found for session %s",
            graph_id,
            session.session_id,
        )
        return _FETCH_FAILED_PREFIX

    version = _sanitize_for_xml(agent_json.get("version") or "")
    raw_name = agent_json.get("name")
    graph_name = (
        raw_name.strip() if isinstance(raw_name, str) and raw_name.strip() else None
    )
    nodes = agent_json.get("nodes") or []
    links = agent_json.get("links") or []
    name_attr = f' name="{_sanitize_for_xml(graph_name)}"' if graph_name else ""
    graph_tag = (
        f'<graph id="{_sanitize_for_xml(graph_id)}"'
        f"{name_attr} "
        f'version="{version}" '
        f'node_count="{len(nodes)}" '
        f'edge_count="{len(links)}"/>'
    )

    inner = f"{graph_tag}\n{_format_nodes(nodes)}\n{_format_links(links, nodes)}"
    return f"<{BUILDER_CONTEXT_TAG}>\n{inner}\n</{BUILDER_CONTEXT_TAG}>\n\n"
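
# Shape of the rendered per-turn prefix (illustrative; the ids and
# names below are invented, not taken from a real graph):
#
# <builder_context>
# <graph id="g1" name="My Agent" version="4" node_count="2" edge_count="1"/>
# <nodes>
# - n1: Fetch page (http-block-id)
# - n2: Summarise (llm-block-id)
# </nodes>
# <links>
# - Fetch page.response -> Summarise.text
# </links>
# </builder_context>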
Some files were not shown because too many files have changed in this diff