mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
1 Commits
fix/copilo
...
fix/artifa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
351001fdca |
@@ -25,8 +25,6 @@ Understand the **Why / What / How** before addressing comments — you need cont
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks.
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
@@ -111,16 +109,12 @@ Only after this loop completes (all pages fetched, count confirmed) should you b
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`).
|
||||
|
||||
### 2. Top-level reviews — REST (MUST paginate)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
|
||||
> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.**
|
||||
|
||||
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
|
||||
|
||||
Two things to extract:
|
||||
@@ -139,8 +133,6 @@ Two things to extract:
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
|
||||
> **Already REST — unaffected by GraphQL rate limits.**
|
||||
|
||||
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
|
||||
|
||||
## For each unaddressed comment
|
||||
@@ -335,65 +327,18 @@ git push
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
|
||||
## GitHub rate limits
|
||||
## GitHub abuse rate limits
|
||||
|
||||
Three distinct rate limits exist — they have different causes, error shapes, and recovery times:
|
||||
Two distinct rate limits exist — they have different causes and recovery times:
|
||||
|
||||
| Error | HTTP code | Cause | Recovery |
|
||||
|---|---|---|---|
|
||||
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
|
||||
| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until `X-RateLimit-Reset` header timestamp |
|
||||
| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit — distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because point costs per query. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works — use fallbacks below. |
|
||||
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |
|
||||
|
||||
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
|
||||
|
||||
### Detection
|
||||
|
||||
The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID <id>` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited OR fully down):
|
||||
|
||||
```bash
|
||||
gh api rate_limit --jq '.resources.graphql' # { "limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...}
|
||||
# Human-readable reset:
|
||||
gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {}
|
||||
```
|
||||
|
||||
Retry when `remaining > 0`. If you need to proceed sooner, sleep 2–5 min and probe again — the limit is per user, not per machine, so other concurrent agents under the same token also consume it.
|
||||
|
||||
### What keeps working
|
||||
|
||||
When GraphQL is unavailable (rate-limited or outage):
|
||||
|
||||
- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe.
|
||||
- **Degraded:** inline thread list — fall back to flat `/pulls/{N}/comments` REST, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply.
|
||||
- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset.
|
||||
|
||||
### Fall back to REST
|
||||
|
||||
**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body' # == --json body
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref' # == --json baseRefName
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable' # == --json mergeable
|
||||
```
|
||||
|
||||
Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again).
|
||||
|
||||
**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \
|
||||
| jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]'
|
||||
```
|
||||
|
||||
Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets.
|
||||
|
||||
**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed, use the same command as the main flow.
|
||||
|
||||
**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance).
|
||||
|
||||
### Recovery from secondary rate limit (403 abuse)
|
||||
|
||||
**Recovery from secondary rate limit (403):**
|
||||
1. Stop all API writes immediately
|
||||
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
|
||||
3. Resume with `sleep 3` between each call
|
||||
@@ -452,8 +397,6 @@ gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREA
|
||||
|
||||
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
|
||||
|
||||
> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow.
|
||||
|
||||
### Verify actual count before outputting ORCHESTRATOR:DONE
|
||||
|
||||
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
---
|
||||
name: pr-polish
|
||||
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
|
||||
user-invocable: true
|
||||
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Polish
|
||||
|
||||
**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:
|
||||
|
||||
1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
|
||||
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
|
||||
3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned.
|
||||
4. Every non-bot, non-author issue comment has been acknowledged (replied-to).
|
||||
5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending.
|
||||
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient.
|
||||
|
||||
**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 in practice.
|
||||
|
||||
## TodoWrite
|
||||
|
||||
Before starting, write two todos so the user can see the loop progression:
|
||||
|
||||
- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
|
||||
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.
|
||||
|
||||
Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).
|
||||
|
||||
## Find the PR
|
||||
|
||||
```bash
|
||||
ARG_PR="${ARG:-}"
|
||||
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
|
||||
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
|
||||
ARG_PR="${BASH_REMATCH[1]}"
|
||||
fi
|
||||
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
|
||||
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
|
||||
echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
|
||||
exit 1
|
||||
fi
|
||||
echo "Polishing PR #$PR"
|
||||
```
|
||||
|
||||
## The outer loop
|
||||
|
||||
```text
|
||||
round = 0
|
||||
while round < _MAX_ROUNDS:
|
||||
round += 1
|
||||
baseline = snapshot_state(PR) # see "Snapshotting state" below
|
||||
invoke_skill("pr-review", PR) # posts findings as inline comments / top-level review
|
||||
findings = diff_state(PR, baseline)
|
||||
if findings.total == 0:
|
||||
break # no new findings → go to polish polling
|
||||
invoke_skill("pr-address", PR) # resolves every unresolved thread + CI failure
|
||||
# Post-loop: polish polling (see below).
|
||||
polish_polling(PR)
|
||||
```
|
||||
|
||||
### Snapshotting state
|
||||
|
||||
Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):
|
||||
|
||||
```bash
|
||||
# Inline threads — total count + latest databaseId per thread
|
||||
gh api graphql -f query="
|
||||
{
|
||||
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
|
||||
pullRequest(number: ${PR}) {
|
||||
reviewThreads(first: 100) {
|
||||
totalCount
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
comments(last: 1) { nodes { databaseId } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}" > /tmp/baseline_threads.json
|
||||
|
||||
# Top-level reviews — count + latest id per non-empty review
|
||||
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
|
||||
--jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
|
||||
> /tmp/baseline_reviews.json
|
||||
|
||||
# Issue comments — count + latest id per non-bot, non-author comment.
|
||||
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
|
||||
# accounts like coderabbitai, github-actions, sentry-io). The author is
|
||||
# filtered by comparing login to the PR author — export it so jq can see it.
|
||||
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
|
||||
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate \
|
||||
--jq --arg author "$AUTHOR" \
|
||||
'[.[] | select(.user.type != "Bot" and .user.login != $author)
|
||||
| {id, user: .user.login, created_at}]' \
|
||||
> /tmp/baseline_issue_comments.json
|
||||
```
|
||||
|
||||
### Diffing after a review
|
||||
|
||||
After `/pr-review` runs, any of these counting as "new findings" means another address round is needed:
|
||||
|
||||
- New inline thread `id` not in the baseline.
|
||||
- An existing thread whose latest comment `databaseId` is higher than the baseline's (new reply on an old thread).
|
||||
- A new top-level review `id` with a non-empty body.
|
||||
- A new issue comment `id` from a non-bot, non-author user.
|
||||
|
||||
If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.
|
||||
|
||||
## Polish polling
|
||||
|
||||
Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals:
|
||||
|
||||
```text
|
||||
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
|
||||
clean_polls = 0
|
||||
required_clean = 2
|
||||
while clean_polls < required_clean:
|
||||
# 1. CI gate — any terminal non-success conclusion (not just "failure")
|
||||
# must trigger /pr-address. "success", "skipped", "neutral" are clean;
|
||||
# anything else (including cancelled, timed_out, action_required) is a
|
||||
# blocker that won't self-resolve.
|
||||
ci = fetch_check_runs(PR)
|
||||
if any ci.conclusion in NON_SUCCESS_TERMINAL:
|
||||
invoke_skill("pr-address", PR) # address failures + any new comments
|
||||
baseline = snapshot_state(PR) # reset — push during address invalidates old baseline
|
||||
clean_polls = 0
|
||||
continue
|
||||
if any ci.conclusion is None (still in_progress):
|
||||
sleep 60; continue # wait without counting this as clean
|
||||
|
||||
# 2. Comment / thread gate
|
||||
threads = fetch_unresolved_threads(PR)
|
||||
new_issue_comments = diff_against_baseline(issue_comments)
|
||||
new_reviews = diff_against_baseline(reviews)
|
||||
if threads or new_issue_comments or new_reviews:
|
||||
invoke_skill("pr-address", PR)
|
||||
baseline = snapshot_state(PR) # reset — the address loop just dealt with these,
|
||||
# otherwise they stay "new" relative to the old baseline forever
|
||||
clean_polls = 0
|
||||
continue
|
||||
|
||||
# 3. Mergeability gate
|
||||
mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
|
||||
if mergeable == false (CONFLICTING):
|
||||
resolve_conflicts(PR) # see pr-address skill
|
||||
clean_polls = 0
|
||||
continue
|
||||
if mergeable is null (UNKNOWN):
|
||||
sleep 60; continue
|
||||
|
||||
clean_polls += 1
|
||||
sleep 60
|
||||
```
|
||||
|
||||
Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
|
||||
|
||||
### Concrete CI fetch (don't parse `gh pr checks` text columns)
|
||||
|
||||
The `fetch_check_runs(PR)` step above must use `--json`, not the default text output. Job names can contain spaces and parentheses (e.g. `test (3.11)`, `Analyze (python)`), so `gh pr checks $PR | awk '{print $2}'` extracts `(3.11)` instead of the status — leading to a clean-poll firing while jobs are still pending.
|
||||
|
||||
```bash
|
||||
# Reliable: use --json so columns are unambiguous.
|
||||
ci_json=$(gh pr checks $PR --repo Significant-Gravitas/AutoGPT --json name,state,bucket)
|
||||
pending=$(echo "$ci_json" | jq '[.[] | select(.bucket == "pending")] | length')
|
||||
failed=$(echo "$ci_json" | jq '[.[] | select(.bucket == "fail" or .bucket == "cancel")] | length')
|
||||
|
||||
# Buckets are: pass | fail | pending | cancel | skipping
|
||||
# (NOTE: gh pr checks does NOT expose `conclusion` as a JSON field —
|
||||
# only `bucket`. Don't confuse with the GitHub REST API's check_runs
|
||||
# endpoint, which DOES use conclusion.)
|
||||
```
|
||||
|
||||
Map back to the pseudocode above: `bucket == "pending"` is `ci.conclusion is None (still in_progress)`; `bucket in {"fail", "cancel"}` is `ci.conclusion in NON_SUCCESS_TERMINAL`; `bucket in {"pass", "skipping"}` is clean.
|
||||
|
||||
### Why 2 clean polls, not 1
|
||||
|
||||
A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.
|
||||
|
||||
### Why checking every source each poll
|
||||
|
||||
`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:
|
||||
|
||||
- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles).
|
||||
- Issue comments from human reviewers (not caught by inline thread polling).
|
||||
- Sentry bug predictions that land on new line numbers post-push.
|
||||
- Merge conflicts introduced by a race between your push and a merge to `dev`.
|
||||
|
||||
## Invocation pattern
|
||||
|
||||
Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.
|
||||
|
||||
```python
|
||||
Skill(skill="pr-review", args=pr_url)
|
||||
Skill(skill="pr-address", args=pr_url)
|
||||
```
|
||||
|
||||
After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.
|
||||
|
||||
### **Auto-continue: do NOT end your response between child skills**
|
||||
|
||||
`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:
|
||||
|
||||
- Do NOT summarize and stop.
|
||||
- Do NOT wait for user confirmation to continue.
|
||||
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.
|
||||
|
||||
The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).
|
||||
|
||||
If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.
|
||||
|
||||
### **Run /pr-polish in the foreground — never in a background agent**
|
||||
|
||||
Spawning `/pr-polish` inside an `Agent(subagent_type="general-purpose")` background task **does not work**. Background agents don't inherit the parent's slash-command registry, so `Skill(skill="pr-review")` and `Skill(skill="pr-address")` calls aren't available — the agent has to manually replicate the child skills' logic, which is fragile and tends to stall on the first network or rate-limit hiccup. Symptom: the background task reports `stalled: no progress for 600s` mid-review.
|
||||
|
||||
Run `/pr-polish` inline in the foreground conversation. If the user asks for "/pr-polish + /pr-test in parallel", split them: foreground `/pr-polish`, and ONLY then can the test step go to a background agent (because `/pr-test` doesn't itself need to invoke skills).
|
||||
|
||||
### **You MUST invoke `Skill(pr-review)` every round — even when bot reviews already exist**
|
||||
|
||||
A common failure mode: CodeRabbit / autogpt-reviewer / Sentry have already posted findings on the PR, and the orchestrator skips the `Skill(pr-review)` step on the assumption that "review has been done." That's wrong — the outer loop's purpose is to layer **the agent's own review** on top of the bot reviews, catching issues the bots miss (architecture, naming, cross-file invariants, hidden coupling). If the orchestrator only addresses bot findings without ever running its own review, the loop converges to "bot-clean" but not "agent-reviewed-clean," and the user reasonably asks "did /pr-polish even read the diff?"
|
||||
|
||||
**Self-check before reporting `ORCHESTRATOR:DONE`:** confirm at least one `Skill(skill="pr-review")` call appears in the current orchestration. If none, the loop is incomplete — go back and run one round.
|
||||
|
||||
## GitHub rate limits
|
||||
|
||||
This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:
|
||||
|
||||
- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
|
||||
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
|
||||
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.
|
||||
|
||||
## Safety valves
|
||||
|
||||
- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
|
||||
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is CI `failure` that will never self-resolve.
|
||||
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved.
|
||||
|
||||
## Reporting
|
||||
|
||||
When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary:
|
||||
|
||||
```
|
||||
PR #{N} polish complete ({rounds_completed} rounds):
|
||||
- {X} inline threads opened and resolved
|
||||
- {Y} CI failures fixed
|
||||
- {Z} new commits pushed
|
||||
Final state: CI green, {total} threads all resolved, mergeable.
|
||||
```
|
||||
|
||||
If exiting via `_MAX_ROUNDS`, flag explicitly:
|
||||
|
||||
```
|
||||
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
|
||||
- {N} threads still unresolved: {titles}
|
||||
- CI status: {summary}
|
||||
Needs human review.
|
||||
```
|
||||
|
||||
## When to use this skill
|
||||
|
||||
Use when the user says any of:
|
||||
- "polish this PR"
|
||||
- "keep reviewing and addressing until it's mergeable"
|
||||
- "loop /pr-review + /pr-address until done"
|
||||
- "make sure the PR is actually merge-ready"
|
||||
|
||||
Do **not** use when:
|
||||
- User wants just one review pass (→ `/pr-review`).
|
||||
- User wants to address already-posted comments without further self-review (→ `/pr-address`).
|
||||
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.
|
||||
@@ -5,7 +5,7 @@ user-invocable: true
|
||||
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "2.1.0"
|
||||
version: "2.0.0"
|
||||
---
|
||||
|
||||
# Manual E2E Test
|
||||
@@ -180,120 +180,6 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
|
||||
|
||||
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
|
||||
|
||||
## Step 3.0: Claim the testing lock (coordinate parallel agents)
|
||||
|
||||
Multiple worktrees share the same host — Docker infra (postgres, redis, clamav), app ports (3000/8006/…), and the test user. Two agents running `/pr-test` concurrently will corrupt each other's state (connection-pool exhaustion, port binds failing silently, cross-test assertions). Use the root-worktree lock file to take turns.
|
||||
|
||||
### Lock file contract
|
||||
|
||||
Path (**always** the root worktree so all siblings see it): `$REPO_ROOT/.ign.testing.lock`
|
||||
|
||||
Body (one `key=value` per line):
|
||||
```
|
||||
holder=<pr-XXXXX-purpose>
|
||||
pid=<pid-or-"self">
|
||||
started=<iso8601>
|
||||
heartbeat=<iso8601, updated every ~2 min>
|
||||
worktree=<full path>
|
||||
branch=<branch name>
|
||||
intent=<one-line description + rough duration>
|
||||
```
|
||||
|
||||
### Claim
|
||||
|
||||
```bash
|
||||
LOCK=$REPO_ROOT/.ign.testing.lock
|
||||
NOW=$(date -u +%Y-%m-%dT%H:%MZ)
|
||||
STALE_AFTER_MIN=5
|
||||
|
||||
if [ -f "$LOCK" ]; then
|
||||
HB=$(grep '^heartbeat=' "$LOCK" | cut -d= -f2)
|
||||
HB_EPOCH=$(date -j -f '%Y-%m-%dT%H:%MZ' "$HB" +%s 2>/dev/null || date -d "$HB" +%s 2>/dev/null || echo 0)
|
||||
AGE_MIN=$(( ( $(date -u +%s) - HB_EPOCH ) / 60 ))
|
||||
if [ "$AGE_MIN" -gt "$STALE_AFTER_MIN" ]; then
|
||||
echo "WARN: stale lock (${AGE_MIN}m old) — reclaiming"
|
||||
cat "$LOCK" | sed 's/^/ stale: /'
|
||||
else
|
||||
echo "Another agent holds the lock:"; cat "$LOCK"
|
||||
echo "Wait until released or resume after $((STALE_AFTER_MIN - AGE_MIN))m."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
cat > "$LOCK" <<EOF
|
||||
holder=pr-${PR_NUMBER}-e2e
|
||||
pid=self
|
||||
started=$NOW
|
||||
heartbeat=$NOW
|
||||
worktree=$WORKTREE_PATH
|
||||
branch=$(cd $WORKTREE_PATH && git branch --show-current)
|
||||
intent=E2E test PR #${PR_NUMBER}, native mode, ~60min
|
||||
EOF
|
||||
echo "Lock claimed"
|
||||
```
|
||||
|
||||
### Heartbeat (MUST run in background during the whole test)
|
||||
|
||||
Without a heartbeat a crashed agent keeps the lock forever. Run this as a background process right after claim:
|
||||
|
||||
```bash
|
||||
(while true; do
|
||||
sleep 120
|
||||
[ -f "$LOCK" ] || exit 0 # lock released → exit heartbeat
|
||||
perl -i -pe "s/^heartbeat=.*/heartbeat=$(date -u +%Y-%m-%dT%H:%MZ)/" "$LOCK"
|
||||
done) &
|
||||
HEARTBEAT_PID=$!
|
||||
echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
|
||||
```
|
||||
|
||||
### Release (always — even on failure)
|
||||
|
||||
```bash
|
||||
kill "$HEARTBEAT_PID" 2>/dev/null
|
||||
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
|
||||
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
|
||||
>> $REPO_ROOT/.ign.testing.log
|
||||
```
|
||||
|
||||
Use a `trap` so release runs even on `exit 1`:
|
||||
```bash
|
||||
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
|
||||
```
|
||||
|
||||
### **Release the lock AS SOON AS the test run is done**
|
||||
|
||||
The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if:
|
||||
|
||||
- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually.
|
||||
- You're leaving docker containers up.
|
||||
- You're tailing logs for a minute or two.
|
||||
|
||||
Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone.
|
||||
|
||||
Concretely, the sequence at the end of every `/pr-test` run (success or failure) is:
|
||||
|
||||
```bash
|
||||
# 1. Write the final report + post PR comment — done above in Step 5/6.
|
||||
# 2. Release the lock right now, even if the app is still up.
|
||||
kill "$HEARTBEAT_PID" 2>/dev/null
|
||||
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
|
||||
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
|
||||
>> $REPO_ROOT/.ign.testing.log
|
||||
# 3. Optionally leave the app running and note it so the user knows:
|
||||
echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
|
||||
echo " pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
|
||||
```
|
||||
|
||||
If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test.
|
||||
|
||||
### Shared status log
|
||||
|
||||
`$REPO_ROOT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
|
||||
```bash
|
||||
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
|
||||
>> $REPO_ROOT/.ign.testing.log
|
||||
```
|
||||
|
||||
## Step 3: Environment setup
|
||||
|
||||
### 3a. Copy .env files from the root worktree
|
||||
@@ -362,87 +248,7 @@ docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocke
|
||||
done
|
||||
```
|
||||
|
||||
**Native mode also:** when running the app natively (see 3e-native), kill any stray host processes and free the app ports before starting — otherwise `poetry run app` and `pnpm dev` will fail to bind.
|
||||
|
||||
```bash
|
||||
# Kill stray native app processes from prior runs
|
||||
pkill -9 -f "python.*backend" 2>/dev/null || true
|
||||
pkill -9 -f "poetry run app" 2>/dev/null || true
|
||||
pkill -9 -f "next-server|next dev" 2>/dev/null || true
|
||||
|
||||
# Free app ports (errors per port are ignored — port may simply be unused)
|
||||
for port in 3000 8006 8001 8002 8005 8008; do
|
||||
lsof -ti :$port -sTCP:LISTEN | xargs -r kill -9 2>/dev/null || true
|
||||
done
|
||||
```
|
||||
|
||||
### 3e-native. Run the app natively (PREFERRED for iterative dev)
|
||||
|
||||
Native mode runs infra (postgres, supabase, redis, rabbitmq, clamav) in docker but runs the backend and frontend directly on the host. This avoids the 3-8 minute `docker compose build` cycle on every backend change — code edits are picked up on process restart (seconds) instead of a full image rebuild.
|
||||
|
||||
**When to prefer native mode (default for this skill):**
|
||||
- Iterative dev/debug loops where you're editing backend or frontend code between test runs
|
||||
- Any PR that touches Python/TS source but not Dockerfiles, compose config, or infra images
|
||||
- Fast repro of a failing scenario — restart `poetry run app` in a couple of seconds
|
||||
|
||||
**When to prefer docker mode (3e fallback):**
|
||||
- Testing changes to `Dockerfile`, `docker-compose.yml`, or base images
|
||||
- Production-parity smoke tests (exact container env, networking, volumes)
|
||||
- CI-equivalent runs where you need the exact image that'll ship
|
||||
|
||||
**Note on 3b (copilot auth):** no npm install anywhere. `poetry install` pulls in `claude_agent_sdk`, which ships its own Claude CLI binary — available on `PATH` whenever you run commands via `poetry run` (native) OR whenever the copilot_executor container is built from its Poetry lockfile (docker). The OAuth token extraction still applies (same `refresh_claude_token.sh` call).
|
||||
|
||||
**Preamble:** before starting native, run the kill-stray + free-ports block from 3c's "Native mode also" subsection.
|
||||
|
||||
**1. Start infra only (one-time per session):**
|
||||
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose --profile local up deps --detach --remove-orphans --build
|
||||
```
|
||||
|
||||
This brings up postgres/supabase/redis/rabbitmq/clamav and skips all app services.
|
||||
|
||||
**2. Start the backend natively:**
|
||||
|
||||
```bash
|
||||
cd $BACKEND_DIR && (poetry run app 2>&1 | tee .ign.application.logs) &
|
||||
```
|
||||
|
||||
`poetry run app` spawns **all** app subprocesses — `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, `database_manager` — inside ONE parent process. No separate containers, no separate terminals. The `.ign.application.logs` prefix is already gitignored.
|
||||
|
||||
**3. Wait for the backend on :8006 BEFORE starting the frontend.** This ordering matters — the frontend's `pnpm dev` startup invokes `generate-api-queries`, which fetches `/openapi.json` from the backend. If the backend isn't listening yet, `pnpm dev` fails immediately.
|
||||
|
||||
```bash
|
||||
for i in $(seq 1 60); do
|
||||
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs 2>/dev/null)" = "200" ]; then
|
||||
echo "Backend ready"
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
```
|
||||
|
||||
**4. Start the frontend natively:**
|
||||
|
||||
```bash
|
||||
cd $FRONTEND_DIR && (pnpm dev 2>&1 | tee .ign.frontend.logs) &
|
||||
```
|
||||
|
||||
**5. Wait for the frontend on :3000:**
|
||||
|
||||
```bash
|
||||
for i in $(seq 1 60); do
|
||||
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:3000 2>/dev/null)" = "200" ]; then
|
||||
echo "Frontend ready"
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
```
|
||||
|
||||
Once both are up, skip 3e/3f and go straight to **3g/3h** (feature flags / test user creation).
|
||||
|
||||
### 3e. Build and start (docker — fallback)
|
||||
### 3e. Build and start
|
||||
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
|
||||
@@ -636,22 +442,6 @@ agent-browser --session-name pr-test snapshot | grep "text:"
|
||||
|
||||
### Checking logs
|
||||
|
||||
**Native mode:** when running via `poetry run app` + `pnpm dev`, all app logs stream to the `.ign.*.logs` files written by the `tee` pipes in 3e-native. `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, and `database_manager` are all subprocesses of the single `poetry run app` parent, so their output is interleaved in `.ign.application.logs`.
|
||||
|
||||
```bash
|
||||
# Backend (all app subprocesses interleaved)
|
||||
tail -f $BACKEND_DIR/.ign.application.logs
|
||||
|
||||
# Frontend (Next.js dev server)
|
||||
tail -f $FRONTEND_DIR/.ign.frontend.logs
|
||||
|
||||
# Filter for errors across either log
|
||||
grep -iE "error|exception|traceback" $BACKEND_DIR/.ign.application.logs | tail -20
|
||||
grep -iE "error|exception|traceback" $FRONTEND_DIR/.ign.frontend.logs | tail -20
|
||||
```
|
||||
|
||||
**Docker mode:**
|
||||
|
||||
```bash
|
||||
# Backend REST server
|
||||
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
|
||||
@@ -781,19 +571,6 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
|
||||
|
||||
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
|
||||
|
||||
**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path.
|
||||
|
||||
**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`):
|
||||
|
||||
```bash
|
||||
# Reject any local-looking absolute path or home-dir shortcut in the body
|
||||
if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then
|
||||
echo "ABORT: local filesystem paths detected in PR comment body."
|
||||
echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting."
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
REPO="Significant-Gravitas/AutoGPT"
|
||||
@@ -1099,15 +876,9 @@ test scenario → find issue (bug OR UX problem) → screenshot broken state
|
||||
### Problem: Frontend shows cookie banner blocking interaction
|
||||
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
|
||||
|
||||
### Problem: Claude CLI not found in copilot_executor container
|
||||
**Symptom:** Copilot logs say `claude: command not found` or similar when starting an SDK turn.
|
||||
**Cause:** Image was built without `poetry install` (stale base layer, or Dockerfile bypass). The SDK CLI ships inside the `claude_agent_sdk` Poetry dep — it is NOT an npm package.
|
||||
**Fix:** Rebuild the image cleanly: `docker compose build --no-cache copilot_executor && docker compose up -d copilot_executor`. Do NOT `docker exec ... npm install -g @anthropic-ai/claude-code` — that is outdated guidance and will pollute the container with a second CLI that the SDK won't use.
|
||||
|
||||
### Problem: agent-browser screenshot hangs / times out
|
||||
**Symptom:** `agent-browser screenshot` exits with code 124 even on `about:blank`.
|
||||
**Cause:** Stuck CDP connection or Chromium process tree. Seen on macOS when a prior `/pr-test` left a zombie Chrome for Testing.
|
||||
**Fix:** `pkill -9 -f "agent-browser|chromium|Chrome for Testing" && sleep 2`, then reopen the browser with a fresh `--session-name`. If still failing, verify via `agent-browser eval` + `agent-browser snapshot` (DOM state) instead of relying on PNGs — the feature under test is the same.
|
||||
### Problem: Container loses npm packages after rebuild
|
||||
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
|
||||
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
|
||||
|
||||
### Problem: Services not starting after `docker compose up`
|
||||
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
|
||||
|
||||
79
.github/workflows/platform-backend-ci.yml
vendored
79
.github/workflows/platform-backend-ci.yml
vendored
@@ -119,12 +119,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
# Redis is provisioned as a real 3-shard cluster below via docker
|
||||
# run (see the "Start Redis Cluster" step). GHA services can't
|
||||
# override the image CMD or stand up multi-container clusters, so
|
||||
# that setup is inlined — it mirrors the topology of the local dev
|
||||
# compose stack (autogpt_platform/docker-compose.platform.yml) and
|
||||
# prod helm chart.
|
||||
redis:
|
||||
image: redis:latest
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
image: rabbitmq:4.1.4
|
||||
ports:
|
||||
@@ -168,68 +166,6 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Start Redis Cluster (3 shards)
|
||||
run: |
|
||||
# 3-master Redis Cluster matching the local compose stack
|
||||
# (autogpt_platform/docker-compose.platform.yml) and prod. Each
|
||||
# shard runs in its own container on a dedicated bridge network,
|
||||
# announces its compose-style hostname for intra-network clients,
|
||||
# and publishes 1700N on the GHA host so tests can reach every
|
||||
# shard via localhost. The backend's ``_address_remap`` rewrites
|
||||
# every CLUSTER SLOTS reply to localhost:<announced-port>, which
|
||||
# picks the right published port per shard.
|
||||
#
|
||||
# Not reusing docker-compose.platform.yml directly because compose
|
||||
# validates the full file even when only some services are ``up``,
|
||||
# and that file references services (db/kong/...) defined in a
|
||||
# sibling compose file — pulling both in would needlessly couple
|
||||
# CI to the full local-dev stack.
|
||||
docker network create redis-cluster-ci
|
||||
for i in 0 1 2; do
|
||||
port=$((17000 + i))
|
||||
bus=$((27000 + i))
|
||||
docker run -d --name redis-$i --network redis-cluster-ci \
|
||||
--network-alias redis-$i \
|
||||
-p $port:$port \
|
||||
redis:7 \
|
||||
redis-server --port $port \
|
||||
--cluster-enabled yes \
|
||||
--cluster-config-file nodes.conf \
|
||||
--cluster-node-timeout 5000 \
|
||||
--cluster-require-full-coverage no \
|
||||
--cluster-announce-hostname redis-$i \
|
||||
--cluster-announce-port $port \
|
||||
--cluster-announce-bus-port $bus \
|
||||
--cluster-preferred-endpoint-type hostname
|
||||
done
|
||||
# Wait for each shard to accept commands.
|
||||
for i in 0 1 2; do
|
||||
port=$((17000 + i))
|
||||
for _ in $(seq 1 30); do
|
||||
docker exec redis-$i redis-cli -p $port ping 2>/dev/null | grep -q PONG && break
|
||||
sleep 1
|
||||
done
|
||||
done
|
||||
# Form the cluster from an init container on the same network so
|
||||
# --cluster-preferred-endpoint-type hostname resolves redis-0/1/2.
|
||||
docker run --rm --network redis-cluster-ci redis:7 \
|
||||
redis-cli --cluster create \
|
||||
redis-0:17000 redis-1:17001 redis-2:17002 \
|
||||
--cluster-replicas 0 --cluster-yes
|
||||
# Confirm convergence.
|
||||
for _ in $(seq 1 30); do
|
||||
state=$(docker exec redis-0 redis-cli -p 17000 cluster info | awk -F: '/^cluster_state:/ {print $2}' | tr -d '[:cntrl:]')
|
||||
if [ "$state" = "ok" ]; then
|
||||
echo "Redis Cluster ready (3 shards, state=ok)"
|
||||
docker exec redis-0 redis-cli -p 17000 cluster nodes
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "Redis Cluster failed to reach ok state" >&2
|
||||
docker exec redis-0 redis-cli -p 17000 cluster info >&2 || true
|
||||
exit 1
|
||||
|
||||
- name: Setup Supabase
|
||||
uses: supabase/setup-cli@v1
|
||||
with:
|
||||
@@ -350,13 +286,8 @@ jobs:
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "17000"
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
# Opt-in: lets backend/data/e2e_redis_restart_test.py spin up its
|
||||
# own isolated 3-shard cluster (ports 27110–27112) and exercise
|
||||
# ``docker restart <shard>`` mid-stream. Off locally so a
|
||||
# contributor's ``poetry run test`` doesn't pay the ~15s cost.
|
||||
E2E_RESTART_ISOLATED: "1"
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -195,8 +195,3 @@ test.db
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.claude/worktrees/
|
||||
test-results/
|
||||
|
||||
# Playwright MCP / local browser-testing artifacts
|
||||
.playwright-mcp/
|
||||
copilot-session-switch-qa/
|
||||
|
||||
@@ -267,7 +267,7 @@
|
||||
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
|
||||
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
|
||||
"is_verified": false,
|
||||
"line_number": 67
|
||||
"line_number": 55
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
|
||||
@@ -467,5 +467,5 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_at": "2026-04-24T16:42:44Z"
|
||||
"generated_at": "2026-04-09T14:20:23Z"
|
||||
}
|
||||
|
||||
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,6 +1,3 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
|
||||
# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
|
||||
.claude/settings.local.json
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class RateLimitSettings(BaseSettings):
|
||||
redis_host: str = Field(
|
||||
default="redis://localhost:6379",
|
||||
description="Redis host",
|
||||
validation_alias="REDIS_HOST",
|
||||
)
|
||||
|
||||
redis_port: str = Field(
|
||||
default="6379", description="Redis port", validation_alias="REDIS_PORT"
|
||||
)
|
||||
|
||||
redis_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Redis password",
|
||||
validation_alias="REDIS_PASSWORD",
|
||||
)
|
||||
|
||||
requests_per_minute: int = Field(
|
||||
default=60,
|
||||
description="Maximum number of requests allowed per minute per API key",
|
||||
validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
|
||||
|
||||
|
||||
RATE_LIMIT_SETTINGS = RateLimitSettings()
|
||||
@@ -0,0 +1,51 @@
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from .config import RATE_LIMIT_SETTINGS
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
|
||||
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
|
||||
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
|
||||
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
|
||||
):
|
||||
self.redis = Redis(
|
||||
host=redis_host,
|
||||
port=int(redis_port),
|
||||
password=redis_password,
|
||||
decode_responses=True,
|
||||
)
|
||||
self.window = 60
|
||||
self.max_requests = requests_per_minute
|
||||
|
||||
async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
|
||||
"""
|
||||
Check if request is within rate limits.
|
||||
|
||||
Args:
|
||||
api_key_id: The API key identifier to check
|
||||
|
||||
Returns:
|
||||
Tuple of (is_allowed, remaining_requests, reset_time)
|
||||
"""
|
||||
now = time.time()
|
||||
window_start = now - self.window
|
||||
key = f"ratelimit:{api_key_id}:1min"
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.zremrangebyscore(key, 0, window_start)
|
||||
pipe.zadd(key, {str(now): now})
|
||||
pipe.zcount(key, window_start, now)
|
||||
pipe.expire(key, self.window)
|
||||
|
||||
_, _, request_count, _ = pipe.execute()
|
||||
|
||||
remaining = max(0, self.max_requests - request_count)
|
||||
reset_time = int(now + self.window)
|
||||
|
||||
return request_count <= self.max_requests, remaining, reset_time
|
||||
@@ -0,0 +1,32 @@
|
||||
from fastapi import HTTPException, Request
|
||||
from starlette.middleware.base import RequestResponseEndpoint
|
||||
|
||||
from .limiter import RateLimiter
|
||||
|
||||
|
||||
async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
|
||||
"""FastAPI middleware for rate limiting API requests."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
if not request.url.path.startswith("/api"):
|
||||
return await call_next(request)
|
||||
|
||||
api_key = request.headers.get("Authorization")
|
||||
if not api_key:
|
||||
return await call_next(request)
|
||||
|
||||
api_key = api_key.replace("Bearer ", "")
|
||||
|
||||
is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)
|
||||
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Rate limit exceeded. Please try again later."
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
response.headers["X-RateLimit-Reset"] = str(reset_time)
|
||||
|
||||
return response
|
||||
@@ -59,8 +59,6 @@ class OAuthState(BaseModel):
|
||||
code_verifier: Optional[str] = None
|
||||
scopes: list[str]
|
||||
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
|
||||
credential_id: Optional[str] = None
|
||||
"""If set, this OAuth flow upgrades an existing credential's scopes."""
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from expiringdict import ExpiringDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
AsyncRedisLike = Union[AsyncRedis, AsyncRedisCluster]
|
||||
|
||||
|
||||
class AsyncRedisKeyedMutex:
|
||||
"""
|
||||
@@ -20,7 +17,7 @@ class AsyncRedisKeyedMutex:
|
||||
in case the key is not unlocked for a specified duration, to prevent memory leaks.
|
||||
"""
|
||||
|
||||
def __init__(self, redis: "AsyncRedisLike", timeout: int | None = 60):
|
||||
def __init__(self, redis: "AsyncRedis", timeout: int | None = 60):
|
||||
self.redis = redis
|
||||
self.timeout = timeout
|
||||
self.locks: dict[Any, "AsyncRedisLock"] = ExpiringDict(
|
||||
|
||||
@@ -37,23 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
# Web Push (VAPID) — generate with: poetry run python -c "
|
||||
# from py_vapid import Vapid; import base64
|
||||
# from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
|
||||
# v = Vapid(); v.generate_keys()
|
||||
# raw_priv = v.private_key.private_numbers().private_value.to_bytes(32, 'big')
|
||||
# print('VAPID_PRIVATE_KEY=' + base64.urlsafe_b64encode(raw_priv).rstrip(b'=').decode())
|
||||
# raw_pub = v.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
|
||||
# print('VAPID_PUBLIC_KEY=' + base64.urlsafe_b64encode(raw_pub).rstrip(b'=').decode())
|
||||
# "
|
||||
# Dev-only keypair below — DO NOT use in staging/production. Regenerate
|
||||
# your own with the snippet above before any non-local deployment.
|
||||
VAPID_PRIVATE_KEY=17hBPdSdn6TR_yAgQxA0TjTcvRj3Lf6znHnASZ4rOKc
|
||||
VAPID_PUBLIC_KEY=BBg49iVTWthVbRYphwmZNvZyiSJDqtSO4nmLxDzLKe3Oo9jbtu0Usa14xX4HQQNLUeiEfzD42zWSlrvY1PR12bs
|
||||
# Per RFC 8292 push services use this in 410 Gone reports; set to a real
|
||||
# mailbox in production. Defaults to a placeholder for local dev.
|
||||
VAPID_CLAIM_EMAIL=mailto:dev@example.com
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
@@ -196,13 +179,6 @@ MEM0_API_KEY=
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Platform Bot Linking
|
||||
PLATFORM_LINK_BASE_URL=http://localhost:3000/link
|
||||
|
||||
# CoPilot chat-platform bridge (Discord/Telegram/Slack)
|
||||
# Uses FRONTEND_BASE_URL (above) for link confirmation pages.
|
||||
AUTOPILOT_BOT_DISCORD_TOKEN=
|
||||
|
||||
# Communication Services
|
||||
DISCORD_BOT_TOKEN=
|
||||
MEDIUM_API_KEY=
|
||||
|
||||
@@ -1,44 +1,14 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Awaitable, Callable, Dict, Optional, Set
|
||||
from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.exceptions import MovedError, RedisError, ResponseError
|
||||
from starlette.websockets import WebSocketState
|
||||
from fastapi import WebSocket
|
||||
|
||||
from backend.api.model import WSMessage, WSMethod
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.event_bus import _assert_no_wildcard
|
||||
from backend.api.model import NotificationPayload, WSMessage, WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionEventType,
|
||||
exec_channel,
|
||||
get_graph_execution_meta,
|
||||
graph_all_channel,
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.data.notification_bus import NotificationEvent
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_settings = Settings()
|
||||
|
||||
|
||||
def _is_ws_close_race(exc: BaseException, websocket: WebSocket) -> bool:
|
||||
"""A SPUBLISH→WS send racing with WS close — benign, drop quietly."""
|
||||
if isinstance(exc, WebSocketDisconnect):
|
||||
return True
|
||||
if (
|
||||
getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED
|
||||
or getattr(websocket, "client_state", None) == WebSocketState.DISCONNECTED
|
||||
):
|
||||
return True
|
||||
if isinstance(exc, RuntimeError) and "close message has been sent" in str(exc):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
@@ -46,379 +16,128 @@ _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
}
|
||||
|
||||
|
||||
def event_bus_channel(channel_key: str) -> str:
|
||||
"""Prefix a channel key with the execution event bus name."""
|
||||
return f"{_settings.config.execution_event_bus_name}/{channel_key}"
|
||||
|
||||
|
||||
def _notification_bus_channel(user_id: str) -> str:
|
||||
"""Return the full sharded channel name for a user's notifications."""
|
||||
return f"{_settings.config.notification_event_bus_name}/{user_id}"
|
||||
|
||||
|
||||
MessageHandler = Callable[[Optional[bytes | str]], Awaitable[None]]
|
||||
|
||||
|
||||
def _is_moved_error(exc: BaseException) -> bool:
|
||||
"""A MOVED redirect — slot migration mid-stream; pump should reconnect."""
|
||||
if isinstance(exc, MovedError):
|
||||
return True
|
||||
if isinstance(exc, ResponseError) and str(exc).startswith("MOVED "):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Reconnect tunables for shard-failover during pubsub.listen().
|
||||
_PUMP_RECONNECT_DEADLINE_S = 60.0
|
||||
_PUMP_RECONNECT_BACKOFF_INITIAL_S = 0.5
|
||||
_PUMP_RECONNECT_BACKOFF_MAX_S = 8.0
|
||||
|
||||
|
||||
class _Subscription:
|
||||
"""One SSUBSCRIBE lifecycle bound to a WebSocket, pinned to the owning shard."""
|
||||
|
||||
def __init__(self, full_channel: str) -> None:
|
||||
_assert_no_wildcard(full_channel)
|
||||
self.full_channel = full_channel
|
||||
self._client: AsyncRedis | None = None
|
||||
self._pubsub: AsyncPubSub | None = None
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
async def start(self, on_message: MessageHandler) -> None:
|
||||
await self._open_pubsub()
|
||||
self._task = asyncio.create_task(self._pump(on_message))
|
||||
|
||||
async def _open_pubsub(self) -> None:
|
||||
"""(Re)establish the sharded pubsub connection + SSUBSCRIBE."""
|
||||
self._client = await redis.connect_sharded_pubsub_async(self.full_channel)
|
||||
self._pubsub = self._client.pubsub()
|
||||
await self._pubsub.execute_command("SSUBSCRIBE", self.full_channel)
|
||||
# redis-py 6.x async PubSub.listen() exits when ``channels`` is
|
||||
# empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
|
||||
self._pubsub.channels[self.full_channel] = None # type: ignore[index]
|
||||
|
||||
async def _close_pubsub_quietly(self) -> None:
|
||||
"""Best-effort teardown before reconnect — never raises."""
|
||||
if self._pubsub is not None:
|
||||
try:
|
||||
await self._pubsub.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._pubsub = None
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._client = None
|
||||
|
||||
async def _pump(self, on_message: MessageHandler) -> None:
|
||||
if self._pubsub is None:
|
||||
return
|
||||
backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
|
||||
deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
|
||||
while True:
|
||||
pubsub = self._pubsub
|
||||
if pubsub is None:
|
||||
return
|
||||
needs_reconnect = False
|
||||
try:
|
||||
async for message in pubsub.listen():
|
||||
msg_type = message.get("type")
|
||||
# Server-pushed sunsubscribe: slot ownership changed and
|
||||
# Redis revoked our SSUBSCRIBE without dropping the TCP.
|
||||
# Treat as a reconnect trigger so we re-resolve the shard.
|
||||
if msg_type == "sunsubscribe":
|
||||
needs_reconnect = True
|
||||
break
|
||||
if msg_type not in ("smessage", "message", "pmessage"):
|
||||
continue
|
||||
# Successful read resets the reconnect budget.
|
||||
backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
|
||||
deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
|
||||
try:
|
||||
await on_message(message.get("data"))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Websocket message-handler failed for channel %s",
|
||||
self.full_channel,
|
||||
)
|
||||
if not needs_reconnect:
|
||||
# listen() exited cleanly (channels emptied) — pump is done.
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except (ConnectionError, RedisError) as exc:
|
||||
if isinstance(exc, ResponseError) and not _is_moved_error(exc):
|
||||
logger.exception(
|
||||
"Pubsub pump crashed on non-retryable ResponseError for %s",
|
||||
self.full_channel,
|
||||
)
|
||||
return
|
||||
if time.monotonic() > deadline:
|
||||
logger.exception(
|
||||
"Pubsub pump giving up after reconnect deadline for %s",
|
||||
self.full_channel,
|
||||
)
|
||||
return
|
||||
logger.warning(
|
||||
"Pubsub pump reconnecting for %s after %s: %s",
|
||||
self.full_channel,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Pubsub pump crashed for %s", self.full_channel)
|
||||
return
|
||||
|
||||
# Either a retryable error was raised, or the server pushed a
|
||||
# sunsubscribe — close the stale pubsub and reopen against the
|
||||
# (possibly migrated) shard.
|
||||
await self._close_pubsub_quietly()
|
||||
await asyncio.sleep(backoff)
|
||||
backoff = min(backoff * 2, _PUMP_RECONNECT_BACKOFF_MAX_S)
|
||||
try:
|
||||
await self._open_pubsub()
|
||||
except (ConnectionError, RedisError) as reopen_exc:
|
||||
logger.warning(
|
||||
"Pubsub pump reopen failed for %s: %s",
|
||||
self.full_channel,
|
||||
reopen_exc,
|
||||
)
|
||||
# Loop again — deadline check will eventually exit.
|
||||
continue
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._task is not None:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
self._task = None
|
||||
if self._pubsub is not None:
|
||||
try:
|
||||
await self._pubsub.execute_command("SUNSUBSCRIBE", self.full_channel)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"SUNSUBSCRIBE failed for %s", self.full_channel, exc_info=True
|
||||
)
|
||||
try:
|
||||
await self._pubsub.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._pubsub = None
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._client = None
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Set[WebSocket] = set()
|
||||
# channel_key → sockets subscribed (public channel keys, not raw Redis channels)
|
||||
self.subscriptions: Dict[str, Set[WebSocket]] = {}
|
||||
# websocket → {channel_key: _Subscription}
|
||||
self._ws_subs: Dict[WebSocket, Dict[str, _Subscription]] = {}
|
||||
# websocket → notification subscription
|
||||
self._ws_notifications: Dict[WebSocket, _Subscription] = {}
|
||||
self.user_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
async def connect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections.add(websocket)
|
||||
self._ws_subs.setdefault(websocket, {})
|
||||
await self._start_notification_subscription(websocket, user_id=user_id)
|
||||
if user_id not in self.user_connections:
|
||||
self.user_connections[user_id] = set()
|
||||
self.user_connections[user_id].add(websocket)
|
||||
|
||||
async def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
self.active_connections.discard(websocket)
|
||||
# Stop SSUBSCRIBE pumps before dropping bookkeeping to avoid leaks.
|
||||
subs = self._ws_subs.pop(websocket, {})
|
||||
for sub in subs.values():
|
||||
await sub.stop()
|
||||
notif_sub = self._ws_notifications.pop(websocket, None)
|
||||
if notif_sub is not None:
|
||||
await notif_sub.stop()
|
||||
for channel_key, subscribers in list(self.subscriptions.items()):
|
||||
for subscribers in self.subscriptions.values():
|
||||
subscribers.discard(websocket)
|
||||
if not subscribers:
|
||||
self.subscriptions.pop(channel_key, None)
|
||||
user_conns = self.user_connections.get(user_id)
|
||||
if user_conns is not None:
|
||||
user_conns.discard(websocket)
|
||||
if not user_conns:
|
||||
self.user_connections.pop(user_id, None)
|
||||
|
||||
async def subscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str:
|
||||
# Hash-tagged channel needs graph_id; resolve once per subscribe.
|
||||
meta = await get_graph_execution_meta(user_id, graph_exec_id)
|
||||
if meta is None:
|
||||
raise ValueError(
|
||||
f"graph_exec #{graph_exec_id} not found for user #{user_id}"
|
||||
)
|
||||
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
|
||||
full_channel = event_bus_channel(
|
||||
exec_channel(user_id, meta.graph_id, graph_exec_id)
|
||||
return await self._subscribe(
|
||||
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
|
||||
)
|
||||
await self._open_subscription(websocket, channel_key, full_channel)
|
||||
return channel_key
|
||||
|
||||
async def subscribe_graph_execs(
|
||||
self, *, user_id: str, graph_id: str, websocket: WebSocket
|
||||
) -> str:
|
||||
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
|
||||
full_channel = event_bus_channel(graph_all_channel(user_id, graph_id))
|
||||
await self._open_subscription(websocket, channel_key, full_channel)
|
||||
return channel_key
|
||||
return await self._subscribe(
|
||||
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
|
||||
)
|
||||
|
||||
async def unsubscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str | None:
|
||||
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
|
||||
return await self._close_subscription(websocket, channel_key)
|
||||
return await self._unsubscribe(
|
||||
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
|
||||
)
|
||||
|
||||
async def unsubscribe_graph_execs(
|
||||
self, *, user_id: str, graph_id: str, websocket: WebSocket
|
||||
) -> str | None:
|
||||
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
|
||||
return await self._close_subscription(websocket, channel_key)
|
||||
return await self._unsubscribe(
|
||||
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
|
||||
)
|
||||
|
||||
async def _open_subscription(
|
||||
self, websocket: WebSocket, channel_key: str, full_channel: str
|
||||
) -> None:
|
||||
self.subscriptions.setdefault(channel_key, set()).add(websocket)
|
||||
per_ws = self._ws_subs.setdefault(websocket, {})
|
||||
if channel_key in per_ws:
|
||||
return
|
||||
sub = _Subscription(full_channel)
|
||||
async def send_execution_update(
|
||||
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
|
||||
) -> int:
|
||||
graph_exec_id = (
|
||||
exec_event.id
|
||||
if isinstance(exec_event, GraphExecutionEvent)
|
||||
else exec_event.graph_exec_id
|
||||
)
|
||||
|
||||
async def on_message(data: Optional[bytes | str]) -> None:
|
||||
await self._forward_exec_event(websocket, channel_key, data)
|
||||
n_sent = 0
|
||||
|
||||
await sub.start(on_message)
|
||||
per_ws[channel_key] = sub
|
||||
|
||||
async def _close_subscription(
|
||||
self, websocket: WebSocket, channel_key: str
|
||||
) -> str | None:
|
||||
subscribers = self.subscriptions.get(channel_key)
|
||||
if subscribers is None:
|
||||
return None
|
||||
subscribers.discard(websocket)
|
||||
if not subscribers:
|
||||
self.subscriptions.pop(channel_key, None)
|
||||
per_ws = self._ws_subs.get(websocket)
|
||||
if per_ws and channel_key in per_ws:
|
||||
sub = per_ws.pop(channel_key)
|
||||
await sub.stop()
|
||||
return channel_key
|
||||
|
||||
async def _forward_exec_event(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
channel_key: str,
|
||||
raw_payload: Optional[bytes | str],
|
||||
) -> None:
|
||||
if raw_payload is None:
|
||||
return
|
||||
# Unwrap the `_EventPayloadWrapper` envelope, then re-wrap as a WS message.
|
||||
try:
|
||||
wrapper = (
|
||||
raw_payload.decode()
|
||||
if isinstance(raw_payload, (bytes, bytearray))
|
||||
else raw_payload
|
||||
channels: set[str] = {
|
||||
# Send update to listeners for this graph execution
|
||||
_graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
|
||||
}
|
||||
if isinstance(exec_event, GraphExecutionEvent):
|
||||
# Send update to listeners for all executions of this graph
|
||||
channels.add(
|
||||
_graph_execs_channel_key(
|
||||
exec_event.user_id, graph_id=exec_event.graph_id
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to decode pubsub payload on %s", channel_key, exc_info=True
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
parsed = json.loads(wrapper)
|
||||
event_data = parsed.get("payload")
|
||||
if not isinstance(event_data, dict):
|
||||
return
|
||||
event_type = event_data.get("event_type")
|
||||
method = _EVENT_TYPE_TO_METHOD_MAP.get(ExecutionEventType(event_type))
|
||||
if method is None:
|
||||
return
|
||||
for channel in channels.intersection(self.subscriptions.keys()):
|
||||
message = WSMessage(
|
||||
method=method,
|
||||
channel=channel_key,
|
||||
data=event_data,
|
||||
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
|
||||
channel=channel,
|
||||
data=exec_event.model_dump(),
|
||||
).model_dump_json()
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
if _is_ws_close_race(e, websocket):
|
||||
logger.debug("Dropped exec event on closed WS for %s", channel_key)
|
||||
return
|
||||
logger.exception("Failed to forward exec event on %s", channel_key)
|
||||
for connection in self.subscriptions[channel]:
|
||||
await connection.send_text(message)
|
||||
n_sent += 1
|
||||
|
||||
async def _start_notification_subscription(
|
||||
self, websocket: WebSocket, *, user_id: str
|
||||
) -> None:
|
||||
full_channel = _notification_bus_channel(user_id)
|
||||
sub = _Subscription(full_channel)
|
||||
return n_sent
|
||||
|
||||
async def on_message(data: Optional[bytes | str]) -> None:
|
||||
await self._forward_notification(websocket, user_id, data)
|
||||
|
||||
try:
|
||||
await sub.start(on_message)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to open notification SSUBSCRIBE for user=%s", user_id
|
||||
)
|
||||
return
|
||||
self._ws_notifications[websocket] = sub
|
||||
|
||||
async def _forward_notification(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
raw_payload: Optional[bytes | str],
|
||||
) -> None:
|
||||
if raw_payload is None:
|
||||
return
|
||||
try:
|
||||
wrapper_json = (
|
||||
raw_payload.decode()
|
||||
if isinstance(raw_payload, (bytes, bytearray))
|
||||
else raw_payload
|
||||
)
|
||||
parsed = json.loads(wrapper_json)
|
||||
inner = parsed.get("payload") if isinstance(parsed, dict) else None
|
||||
if not isinstance(inner, dict):
|
||||
return
|
||||
event = NotificationEvent.model_validate(inner)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse notification payload for user=%s",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
# Defense in depth against cross-user payloads.
|
||||
if event.user_id != user_id:
|
||||
return
|
||||
async def send_notification(
|
||||
self, *, user_id: str, payload: NotificationPayload
|
||||
) -> int:
|
||||
"""Send a notification to all websocket connections belonging to a user."""
|
||||
message = WSMessage(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data=event.payload.model_dump(),
|
||||
data=payload.model_dump(),
|
||||
).model_dump_json()
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
if _is_ws_close_race(e, websocket):
|
||||
logger.debug("Dropped notification on closed WS for user=%s", user_id)
|
||||
return
|
||||
logger.warning(
|
||||
"Failed to deliver notification to WS for user=%s",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
connections = tuple(self.user_connections.get(user_id, set()))
|
||||
if not connections:
|
||||
return 0
|
||||
|
||||
await asyncio.gather(
|
||||
*(connection.send_text(message) for connection in connections),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
self.subscriptions[channel_key].add(websocket)
|
||||
return channel_key
|
||||
|
||||
async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
|
||||
if channel_key in self.subscriptions:
|
||||
self.subscriptions[channel_key].discard(websocket)
|
||||
if not self.subscriptions[channel_key]:
|
||||
del self.subscriptions[channel_key]
|
||||
return channel_key
|
||||
return None
|
||||
|
||||
|
||||
def graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
|
||||
def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
|
||||
return f"{user_id}|graph_exec#{graph_exec_id}"
|
||||
|
||||
|
||||
|
||||
@@ -1,386 +0,0 @@
|
||||
"""ConnectionManager integration over the live 3-shard Redis cluster:
|
||||
SSUBSCRIBE → SPUBLISH → WebSocket forwarding with no Redis mocks. Skips
|
||||
when the cluster is unreachable."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import WebSocket
|
||||
|
||||
import backend.data.redis_client as redis_client
|
||||
from backend.api.conn_manager import (
|
||||
ConnectionManager,
|
||||
_graph_execs_channel_key,
|
||||
event_bus_channel,
|
||||
graph_exec_channel_key,
|
||||
)
|
||||
from backend.api.model import WSMethod
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionEvent,
|
||||
GraphExecutionMeta,
|
||||
NodeExecutionEvent,
|
||||
exec_channel,
|
||||
graph_all_channel,
|
||||
)
|
||||
|
||||
|
||||
def _has_live_cluster() -> bool:
|
||||
try:
|
||||
c = redis_client.connect()
|
||||
except Exception: # noqa: BLE001 — any connect failure → skip
|
||||
return False
|
||||
try:
|
||||
c.close()
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not _has_live_cluster(),
|
||||
reason="local redis cluster not reachable; skip conn_manager integration",
|
||||
)
|
||||
|
||||
|
||||
def _meta(user_id: str, graph_id: str, graph_exec_id: str) -> GraphExecutionMeta:
|
||||
"""Build a minimal GraphExecutionMeta for ``subscribe_graph_exec`` to use."""
|
||||
return GraphExecutionMeta(
|
||||
id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
ended_at=None,
|
||||
stats=GraphExecutionMeta.Stats(),
|
||||
)
|
||||
|
||||
|
||||
def _node_event_payload(
|
||||
*, user_id: str, graph_id: str, graph_exec_id: str, marker: str
|
||||
) -> bytes:
|
||||
"""Wire-format a NodeExecutionEvent the way RedisExecutionEventBus would."""
|
||||
inner = NodeExecutionEvent(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=1,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_exec_id=f"node-exec-{marker}",
|
||||
node_id="node-1",
|
||||
block_id="block-1",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
input_data={"in": marker},
|
||||
output_data={"out": [marker]},
|
||||
add_time=datetime.now(tz=timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=datetime.now(tz=timezone.utc),
|
||||
end_time=datetime.now(tz=timezone.utc),
|
||||
).model_dump(mode="json")
|
||||
return json.dumps({"payload": inner}).encode()
|
||||
|
||||
|
||||
def _graph_event_payload(
|
||||
*, user_id: str, graph_id: str, graph_exec_id: str, marker: str
|
||||
) -> bytes:
|
||||
inner = GraphExecutionEvent(
|
||||
id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=1,
|
||||
preset_id=None,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
ended_at=datetime.now(tz=timezone.utc),
|
||||
stats=GraphExecutionEvent.Stats(
|
||||
cost=0,
|
||||
duration=1.0,
|
||||
node_exec_time=0.5,
|
||||
node_exec_count=1,
|
||||
),
|
||||
inputs={"x": marker},
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
outputs={"y": [marker]},
|
||||
).model_dump(mode="json")
|
||||
return json.dumps({"payload": inner}).encode()
|
||||
|
||||
|
||||
async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
|
||||
"""Poll ``predicate()`` until truthy or timeout — used to wait for pubsub."""
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
if predicate():
|
||||
return True
|
||||
await asyncio.sleep(interval)
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_clients_get_independent_ssubscribes_on_right_shards(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""Two WS clients on different graph_exec_ids each receive ONLY their
|
||||
own publish, even when the channels land on different shards."""
|
||||
user_id = "user-conn-int-1"
|
||||
graph_a = f"graph-a-{uuid4().hex[:8]}"
|
||||
graph_b = f"graph-b-{uuid4().hex[:8]}"
|
||||
exec_a = f"exec-a-{uuid4().hex[:8]}"
|
||||
exec_b = f"exec-b-{uuid4().hex[:8]}"
|
||||
|
||||
# Stub Prisma lookup so tests don't need a DB.
|
||||
async def _fake_meta(_uid, gex_id):
|
||||
return _meta(user_id, graph_a if gex_id == exec_a else graph_b, gex_id)
|
||||
|
||||
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
|
||||
|
||||
cm = ConnectionManager()
|
||||
ws_a: AsyncMock = AsyncMock(spec=WebSocket)
|
||||
ws_b: AsyncMock = AsyncMock(spec=WebSocket)
|
||||
sent_a: list[str] = []
|
||||
sent_b: list[str] = []
|
||||
ws_a.send_text = AsyncMock(side_effect=lambda m: sent_a.append(m))
|
||||
ws_b.send_text = AsyncMock(side_effect=lambda m: sent_b.append(m))
|
||||
|
||||
redis_client.get_redis.cache_clear()
|
||||
cluster = redis_client.get_redis()
|
||||
|
||||
try:
|
||||
await cm.subscribe_graph_exec(
|
||||
user_id=user_id, graph_exec_id=exec_a, websocket=ws_a
|
||||
)
|
||||
await cm.subscribe_graph_exec(
|
||||
user_id=user_id, graph_exec_id=exec_b, websocket=ws_b
|
||||
)
|
||||
# Let SSUBSCRIBE settle on each shard.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Publish to each per-exec channel.
|
||||
chan_a = event_bus_channel(exec_channel(user_id, graph_a, exec_a))
|
||||
chan_b = event_bus_channel(exec_channel(user_id, graph_b, exec_b))
|
||||
cluster.spublish(
|
||||
chan_a,
|
||||
_node_event_payload(
|
||||
user_id=user_id,
|
||||
graph_id=graph_a,
|
||||
graph_exec_id=exec_a,
|
||||
marker="A",
|
||||
).decode(),
|
||||
)
|
||||
cluster.spublish(
|
||||
chan_b,
|
||||
_node_event_payload(
|
||||
user_id=user_id,
|
||||
graph_id=graph_b,
|
||||
graph_exec_id=exec_b,
|
||||
marker="B",
|
||||
).decode(),
|
||||
)
|
||||
|
||||
delivered = await _wait_until(lambda: sent_a and sent_b, timeout=5.0)
|
||||
assert delivered, f"timeout: sent_a={sent_a!r} sent_b={sent_b!r}"
|
||||
|
||||
msg_a = json.loads(sent_a[0])
|
||||
msg_b = json.loads(sent_b[0])
|
||||
assert msg_a["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_a)
|
||||
assert msg_b["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_b)
|
||||
assert msg_a["data"]["graph_exec_id"] == exec_a
|
||||
assert msg_b["data"]["graph_exec_id"] == exec_b
|
||||
# No cross-talk: each socket got exactly one message.
|
||||
assert len(sent_a) == 1 and len(sent_b) == 1
|
||||
finally:
|
||||
await cm.disconnect_socket(ws_a, user_id=user_id)
|
||||
await cm.disconnect_socket(ws_b, user_id=user_id)
|
||||
redis_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregate_channel_receives_per_exec_publishes(monkeypatch) -> None:
|
||||
"""A subscriber on the ``graph_execs`` aggregate channel must receive the
|
||||
GraphExecutionEvent published to the ``/all`` channel — even though
|
||||
per-exec events go to a different channel."""
|
||||
user_id = "user-conn-int-2"
|
||||
graph_id = f"graph-{uuid4().hex[:8]}"
|
||||
exec_id = f"exec-{uuid4().hex[:8]}"
|
||||
|
||||
async def _fake_meta(_uid, gex_id):
|
||||
return _meta(user_id, graph_id, gex_id)
|
||||
|
||||
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
|
||||
|
||||
cm = ConnectionManager()
|
||||
ws_agg: AsyncMock = AsyncMock(spec=WebSocket)
|
||||
ws_per: AsyncMock = AsyncMock(spec=WebSocket)
|
||||
sent_agg: list[str] = []
|
||||
sent_per: list[str] = []
|
||||
ws_agg.send_text = AsyncMock(side_effect=lambda m: sent_agg.append(m))
|
||||
ws_per.send_text = AsyncMock(side_effect=lambda m: sent_per.append(m))
|
||||
|
||||
redis_client.get_redis.cache_clear()
|
||||
cluster = redis_client.get_redis()
|
||||
|
||||
try:
|
||||
await cm.subscribe_graph_execs(
|
||||
user_id=user_id, graph_id=graph_id, websocket=ws_agg
|
||||
)
|
||||
await cm.subscribe_graph_exec(
|
||||
user_id=user_id, graph_exec_id=exec_id, websocket=ws_per
|
||||
)
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# The eventbus publishes the same event to both channels — replicate.
|
||||
chan_per = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
|
||||
chan_all = event_bus_channel(graph_all_channel(user_id, graph_id))
|
||||
payload = _graph_event_payload(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=exec_id,
|
||||
marker="agg",
|
||||
).decode()
|
||||
cluster.spublish(chan_per, payload)
|
||||
cluster.spublish(chan_all, payload)
|
||||
|
||||
delivered = await _wait_until(lambda: sent_agg and sent_per, timeout=5.0)
|
||||
assert delivered, f"sent_agg={sent_agg!r} sent_per={sent_per!r}"
|
||||
agg_msg = json.loads(sent_agg[0])
|
||||
per_msg = json.loads(sent_per[0])
|
||||
# Aggregate subscriber's channel key is the per-graph executions key.
|
||||
assert agg_msg["channel"] == _graph_execs_channel_key(
|
||||
user_id, graph_id=graph_id
|
||||
)
|
||||
assert per_msg["channel"] == graph_exec_channel_key(
|
||||
user_id, graph_exec_id=exec_id
|
||||
)
|
||||
assert agg_msg["method"] == WSMethod.GRAPH_EXECUTION_EVENT.value
|
||||
finally:
|
||||
await cm.disconnect_socket(ws_agg, user_id=user_id)
|
||||
await cm.disconnect_socket(ws_per, user_id=user_id)
|
||||
redis_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_unsubscribes_and_drops_future_publishes(monkeypatch) -> None:
|
||||
"""After ``disconnect_socket`` runs, a subsequent SPUBLISH must NOT reach
|
||||
the dead websocket — exercises the SUNSUBSCRIBE + bookkeeping cleanup."""
|
||||
user_id = "user-conn-int-3"
|
||||
graph_id = f"graph-{uuid4().hex[:8]}"
|
||||
exec_id = f"exec-{uuid4().hex[:8]}"
|
||||
|
||||
async def _fake_meta(_uid, gex_id):
|
||||
return _meta(user_id, graph_id, gex_id)
|
||||
|
||||
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
|
||||
|
||||
cm = ConnectionManager()
|
||||
ws: AsyncMock = AsyncMock(spec=WebSocket)
|
||||
sent: list[str] = []
|
||||
ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
|
||||
|
||||
redis_client.get_redis.cache_clear()
|
||||
cluster = redis_client.get_redis()
|
||||
chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
|
||||
payload = _node_event_payload(
|
||||
user_id=user_id, graph_id=graph_id, graph_exec_id=exec_id, marker="live"
|
||||
).decode()
|
||||
|
||||
try:
|
||||
await cm.subscribe_graph_exec(
|
||||
user_id=user_id, graph_exec_id=exec_id, websocket=ws
|
||||
)
|
||||
await asyncio.sleep(0.15)
|
||||
|
||||
# First publish — must reach the socket.
|
||||
cluster.spublish(chan, payload)
|
||||
delivered = await _wait_until(lambda: bool(sent), timeout=5.0)
|
||||
assert delivered
|
||||
assert len(sent) == 1
|
||||
|
||||
# Disconnect → SUNSUBSCRIBE + bookkeeping cleared.
|
||||
await cm.disconnect_socket(ws, user_id=user_id)
|
||||
# Pump cancellation may drain in-flight messages; wait for it.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Channel bookkeeping must be gone.
|
||||
assert (
|
||||
graph_exec_channel_key(user_id, graph_exec_id=exec_id)
|
||||
not in cm.subscriptions
|
||||
)
|
||||
assert ws not in cm._ws_subs
|
||||
|
||||
# Second publish — must NOT reach the (already-disconnected) socket.
|
||||
cluster.spublish(
|
||||
chan,
|
||||
_node_event_payload(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=exec_id,
|
||||
marker="post-disconnect",
|
||||
).decode(),
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
# Still only the one pre-disconnect message.
|
||||
assert len(sent) == 1
|
||||
finally:
|
||||
redis_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_consumer_receives_all_events_without_loss(monkeypatch) -> None:
|
||||
"""Burst-publish many SPUBLISHes; assert every one reaches the subscriber
|
||||
in order — guards against drops/reorderings in the pubsub pump."""
|
||||
user_id = "user-conn-int-4"
|
||||
graph_id = f"graph-{uuid4().hex[:8]}"
|
||||
exec_id = f"exec-{uuid4().hex[:8]}"
|
||||
n_events = 100
|
||||
|
||||
async def _fake_meta(_uid, gex_id):
|
||||
return _meta(user_id, graph_id, gex_id)
|
||||
|
||||
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
|
||||
|
||||
cm = ConnectionManager()
|
||||
ws: AsyncMock = AsyncMock(spec=WebSocket)
|
||||
sent: list[str] = []
|
||||
ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
|
||||
|
||||
redis_client.get_redis.cache_clear()
|
||||
cluster = redis_client.get_redis()
|
||||
chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
|
||||
|
||||
try:
|
||||
await cm.subscribe_graph_exec(
|
||||
user_id=user_id, graph_exec_id=exec_id, websocket=ws
|
||||
)
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Burst-publish n_events without yielding to the pump.
|
||||
for i in range(n_events):
|
||||
cluster.spublish(
|
||||
chan,
|
||||
_node_event_payload(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=exec_id,
|
||||
marker=f"m{i}",
|
||||
).decode(),
|
||||
)
|
||||
|
||||
delivered = await _wait_until(
|
||||
lambda: len(sent) >= n_events, timeout=15.0, interval=0.1
|
||||
)
|
||||
assert delivered, f"only delivered {len(sent)}/{n_events}"
|
||||
|
||||
# Validate ordering — Redis pub/sub is FIFO per channel.
|
||||
markers = [json.loads(m)["data"]["input_data"]["in"] for m in sent[:n_events]]
|
||||
assert markers == [f"m{i}" for i in range(n_events)]
|
||||
finally:
|
||||
await cm.disconnect_socket(ws, user_id=user_id)
|
||||
redis_client.disconnect()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,932 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from autogpt_libs.auth import requires_admin_user
|
||||
from autogpt_libs.auth.models import User as AuthUser
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.admin.model import (
|
||||
AgentDiagnosticsResponse,
|
||||
ExecutionDiagnosticsResponse,
|
||||
)
|
||||
from backend.data.diagnostics import (
|
||||
FailedExecutionDetail,
|
||||
OrphanedScheduleDetail,
|
||||
RunningExecutionDetail,
|
||||
ScheduleDetail,
|
||||
ScheduleHealthMetrics,
|
||||
cleanup_all_stuck_queued_executions,
|
||||
cleanup_orphaned_executions_bulk,
|
||||
cleanup_orphaned_schedules_bulk,
|
||||
get_agent_diagnostics,
|
||||
get_all_orphaned_execution_ids,
|
||||
get_all_schedules_details,
|
||||
get_all_stuck_queued_execution_ids,
|
||||
get_execution_diagnostics,
|
||||
get_failed_executions_count,
|
||||
get_failed_executions_details,
|
||||
get_invalid_executions_details,
|
||||
get_long_running_executions_details,
|
||||
get_orphaned_executions_details,
|
||||
get_orphaned_schedules_details,
|
||||
get_running_executions_details,
|
||||
get_schedule_health_metrics,
|
||||
get_stuck_queued_executions_details,
|
||||
stop_all_long_running_executions,
|
||||
)
|
||||
from backend.data.execution import get_graph_executions
|
||||
from backend.executor.utils import add_graph_execution, stop_graph_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["diagnostics", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class RunningExecutionsListResponse(BaseModel):
|
||||
"""Response model for list of running executions"""
|
||||
|
||||
executions: List[RunningExecutionDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class FailedExecutionsListResponse(BaseModel):
|
||||
"""Response model for list of failed executions"""
|
||||
|
||||
executions: List[FailedExecutionDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class StopExecutionRequest(BaseModel):
|
||||
"""Request model for stopping a single execution"""
|
||||
|
||||
execution_id: str
|
||||
|
||||
|
||||
class StopExecutionsRequest(BaseModel):
|
||||
"""Request model for stopping multiple executions"""
|
||||
|
||||
execution_ids: List[str]
|
||||
|
||||
|
||||
class StopExecutionResponse(BaseModel):
|
||||
"""Response model for stop execution operations"""
|
||||
|
||||
success: bool
|
||||
stopped_count: int = 0
|
||||
message: str
|
||||
|
||||
|
||||
class RequeueExecutionResponse(BaseModel):
|
||||
"""Response model for requeue execution operations"""
|
||||
|
||||
success: bool
|
||||
requeued_count: int = 0
|
||||
message: str
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions",
|
||||
response_model=ExecutionDiagnosticsResponse,
|
||||
summary="Get Execution Diagnostics",
|
||||
)
|
||||
async def get_execution_diagnostics_endpoint():
|
||||
"""
|
||||
Get comprehensive diagnostic information about execution status.
|
||||
|
||||
Returns all execution metrics including:
|
||||
- Current state (running, queued)
|
||||
- Orphaned executions (>24h old, likely not in executor)
|
||||
- Failure metrics (1h, 24h, rate)
|
||||
- Long-running detection (stuck >1h, >24h)
|
||||
- Stuck queued detection
|
||||
- Throughput metrics (completions/hour)
|
||||
- RabbitMQ queue depths
|
||||
"""
|
||||
logger.info("Getting execution diagnostics")
|
||||
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
|
||||
response = ExecutionDiagnosticsResponse(
|
||||
running_executions=diagnostics.running_count,
|
||||
queued_executions_db=diagnostics.queued_db_count,
|
||||
queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth,
|
||||
cancel_queue_depth=diagnostics.cancel_queue_depth,
|
||||
orphaned_running=diagnostics.orphaned_running,
|
||||
orphaned_queued=diagnostics.orphaned_queued,
|
||||
failed_count_1h=diagnostics.failed_count_1h,
|
||||
failed_count_24h=diagnostics.failed_count_24h,
|
||||
failure_rate_24h=diagnostics.failure_rate_24h,
|
||||
stuck_running_24h=diagnostics.stuck_running_24h,
|
||||
stuck_running_1h=diagnostics.stuck_running_1h,
|
||||
oldest_running_hours=diagnostics.oldest_running_hours,
|
||||
stuck_queued_1h=diagnostics.stuck_queued_1h,
|
||||
queued_never_started=diagnostics.queued_never_started,
|
||||
invalid_queued_with_start=diagnostics.invalid_queued_with_start,
|
||||
invalid_running_without_start=diagnostics.invalid_running_without_start,
|
||||
completed_1h=diagnostics.completed_1h,
|
||||
completed_24h=diagnostics.completed_24h,
|
||||
throughput_per_hour=diagnostics.throughput_per_hour,
|
||||
timestamp=diagnostics.timestamp,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Execution diagnostics: running={diagnostics.running_count}, "
|
||||
f"queued_db={diagnostics.queued_db_count}, "
|
||||
f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, "
|
||||
f"failed_24h={diagnostics.failed_count_24h}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/agents",
|
||||
response_model=AgentDiagnosticsResponse,
|
||||
summary="Get Agent Diagnostics",
|
||||
)
|
||||
async def get_agent_diagnostics_endpoint():
|
||||
"""
|
||||
Get diagnostic information about agents.
|
||||
|
||||
Returns:
|
||||
- agents_with_active_executions: Number of unique agents with running/queued executions
|
||||
- timestamp: Current timestamp
|
||||
"""
|
||||
logger.info("Getting agent diagnostics")
|
||||
|
||||
diagnostics = await get_agent_diagnostics()
|
||||
|
||||
response = AgentDiagnosticsResponse(
|
||||
agents_with_active_executions=diagnostics.agents_with_active_executions,
|
||||
timestamp=diagnostics.timestamp,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/running",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Running Executions",
|
||||
)
|
||||
async def list_running_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of running and queued executions (recent, likely active).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of running executions with details
|
||||
"""
|
||||
logger.info(f"Listing running executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_running_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.running_count + diagnostics.queued_db_count
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/orphaned",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Orphaned Executions",
|
||||
)
|
||||
async def list_orphaned_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of orphaned executions (>24h old, likely not in executor).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of orphaned executions with details
|
||||
"""
|
||||
logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_orphaned_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.orphaned_running + diagnostics.orphaned_queued
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/failed",
|
||||
response_model=FailedExecutionsListResponse,
|
||||
summary="List Failed Executions",
|
||||
)
|
||||
async def list_failed_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
hours: int = 24,
|
||||
):
|
||||
"""
|
||||
Get detailed list of failed executions.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
hours: Number of hours to look back (default 24)
|
||||
|
||||
Returns:
|
||||
List of failed executions with error details
|
||||
"""
|
||||
logger.info(
|
||||
f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})"
|
||||
)
|
||||
|
||||
executions = await get_failed_executions_details(
|
||||
limit=limit, offset=offset, hours=hours
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
# Always count actual total for given hours parameter
|
||||
total = await get_failed_executions_count(hours=hours)
|
||||
|
||||
return FailedExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/long-running",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Long-Running Executions",
|
||||
)
|
||||
async def list_long_running_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of long-running executions (RUNNING status >24h).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of long-running executions with details
|
||||
"""
|
||||
logger.info(f"Listing long-running executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_long_running_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.stuck_running_24h
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/stuck-queued",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Stuck Queued Executions",
|
||||
)
|
||||
async def list_stuck_queued_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of stuck queued executions (QUEUED >1h, never started).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of stuck queued executions with details
|
||||
"""
|
||||
logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_stuck_queued_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = diagnostics.stuck_queued_1h
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/executions/invalid",
|
||||
response_model=RunningExecutionsListResponse,
|
||||
summary="List Invalid Executions",
|
||||
)
|
||||
async def list_invalid_executions(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of executions in invalid states (READ-ONLY).
|
||||
|
||||
Invalid states indicate data corruption and require manual investigation:
|
||||
- QUEUED but has startedAt (impossible - can't start while queued)
|
||||
- RUNNING but no startedAt (impossible - can't run without starting)
|
||||
|
||||
⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation.
|
||||
|
||||
Each invalid execution likely has a different root cause (crashes, race conditions,
|
||||
DB corruption). Investigate the execution history and logs to determine appropriate
|
||||
action (manual cleanup, status fix, or leave as-is if system recovered).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of executions to return (default 100)
|
||||
offset: Number of executions to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of invalid state executions with details
|
||||
"""
|
||||
logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})")
|
||||
|
||||
executions = await get_invalid_executions_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count for pagination
|
||||
diagnostics = await get_execution_diagnostics()
|
||||
total = (
|
||||
diagnostics.invalid_queued_with_start
|
||||
+ diagnostics.invalid_running_without_start
|
||||
)
|
||||
|
||||
return RunningExecutionsListResponse(executions=executions, total=total)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/requeue",
|
||||
response_model=RequeueExecutionResponse,
|
||||
summary="Requeue Stuck Execution",
|
||||
)
|
||||
async def requeue_single_execution(
|
||||
request: StopExecutionRequest, # Reuse same request model (has execution_id)
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Requeue a stuck QUEUED execution (admin only).
|
||||
|
||||
Uses add_graph_execution with existing graph_exec_id to requeue.
|
||||
|
||||
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
|
||||
|
||||
Args:
|
||||
request: Contains execution_id to requeue
|
||||
|
||||
Returns:
|
||||
Success status and message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}")
|
||||
|
||||
# Get the execution (validation - must be QUEUED)
|
||||
executions = await get_graph_executions(
|
||||
graph_exec_id=request.execution_id,
|
||||
statuses=[AgentExecutionStatus.QUEUED],
|
||||
)
|
||||
|
||||
if not executions:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Execution not found or not in QUEUED status",
|
||||
)
|
||||
|
||||
execution = executions[0]
|
||||
|
||||
# Use add_graph_execution in requeue mode
|
||||
await add_graph_execution(
|
||||
graph_id=execution.graph_id,
|
||||
user_id=execution.user_id,
|
||||
graph_version=execution.graph_version,
|
||||
graph_exec_id=request.execution_id, # Requeue existing execution
|
||||
)
|
||||
|
||||
return RequeueExecutionResponse(
|
||||
success=True,
|
||||
requeued_count=1,
|
||||
message="Execution requeued successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/requeue-bulk",
|
||||
response_model=RequeueExecutionResponse,
|
||||
summary="Requeue Multiple Stuck Executions",
|
||||
)
|
||||
async def requeue_multiple_executions(
|
||||
request: StopExecutionsRequest, # Reuse same request model (has execution_ids)
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Requeue multiple stuck QUEUED executions (admin only).
|
||||
|
||||
Uses add_graph_execution with existing graph_exec_id to requeue.
|
||||
|
||||
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
|
||||
|
||||
Args:
|
||||
request: Contains list of execution_ids to requeue
|
||||
|
||||
Returns:
|
||||
Number of executions requeued and success message
|
||||
"""
|
||||
logger.info(
|
||||
f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions"
|
||||
)
|
||||
|
||||
# Get executions by ID list (must be QUEUED)
|
||||
executions = await get_graph_executions(
|
||||
execution_ids=request.execution_ids,
|
||||
statuses=[AgentExecutionStatus.QUEUED],
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return RequeueExecutionResponse(
|
||||
success=False,
|
||||
requeued_count=0,
|
||||
message="No QUEUED executions found to requeue",
|
||||
)
|
||||
|
||||
# Requeue all executions in parallel using add_graph_execution
|
||||
async def requeue_one(exec) -> bool:
|
||||
try:
|
||||
await add_graph_execution(
|
||||
graph_id=exec.graph_id,
|
||||
user_id=exec.user_id,
|
||||
graph_version=exec.graph_version,
|
||||
graph_exec_id=exec.id, # Requeue existing
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to requeue {exec.id}: {e}")
|
||||
return False
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[requeue_one(exec) for exec in executions], return_exceptions=False
|
||||
)
|
||||
|
||||
requeued_count = sum(1 for success in results if success)
|
||||
|
||||
return RequeueExecutionResponse(
|
||||
success=requeued_count > 0,
|
||||
requeued_count=requeued_count,
|
||||
message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/stop",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Stop Single Execution",
|
||||
)
|
||||
async def stop_single_execution(
|
||||
request: StopExecutionRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Stop a single execution (admin only).
|
||||
|
||||
Uses robust stop_graph_execution which cascades to children and waits for termination.
|
||||
|
||||
Args:
|
||||
request: Contains execution_id to stop
|
||||
|
||||
Returns:
|
||||
Success status and message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}")
|
||||
|
||||
# Get the execution to find its owner user_id (required by stop_graph_execution)
|
||||
executions = await get_graph_executions(
|
||||
graph_exec_id=request.execution_id,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
execution = executions[0]
|
||||
|
||||
# Use robust stop_graph_execution (cascades to children, waits for termination)
|
||||
await stop_graph_execution(
|
||||
user_id=execution.user_id,
|
||||
graph_exec_id=request.execution_id,
|
||||
wait_timeout=15.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=True,
|
||||
stopped_count=1,
|
||||
message="Execution stopped successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/stop-bulk",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Stop Multiple Executions",
|
||||
)
|
||||
async def stop_multiple_executions(
|
||||
request: StopExecutionsRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Stop multiple active executions (admin only).
|
||||
|
||||
Uses robust stop_graph_execution which cascades to children and waits for termination.
|
||||
|
||||
Args:
|
||||
request: Contains list of execution_ids to stop
|
||||
|
||||
Returns:
|
||||
Number of executions stopped and success message
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"Admin {user.user_id} stopping {len(request.execution_ids)} executions"
|
||||
)
|
||||
|
||||
# Get executions by ID list
|
||||
executions = await get_graph_executions(
|
||||
execution_ids=request.execution_ids,
|
||||
)
|
||||
|
||||
if not executions:
|
||||
return StopExecutionResponse(
|
||||
success=False,
|
||||
stopped_count=0,
|
||||
message="No executions found",
|
||||
)
|
||||
|
||||
# Stop all executions in parallel using robust stop_graph_execution
|
||||
async def stop_one(exec) -> bool:
|
||||
try:
|
||||
await stop_graph_execution(
|
||||
user_id=exec.user_id,
|
||||
graph_exec_id=exec.id,
|
||||
wait_timeout=15.0,
|
||||
cascade=True,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop execution {exec.id}: {e}")
|
||||
return False
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[stop_one(exec) for exec in executions], return_exceptions=False
|
||||
)
|
||||
|
||||
stopped_count = sum(1 for success in results if success)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=stopped_count > 0,
|
||||
stopped_count=stopped_count,
|
||||
message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/cleanup-orphaned",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Cleanup Orphaned Executions",
|
||||
)
|
||||
async def cleanup_orphaned_executions(
|
||||
request: StopExecutionsRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup orphaned executions by directly updating DB status (admin only).
|
||||
For executions in DB but not actually running in executor (old/stale records).
|
||||
|
||||
Args:
|
||||
request: Contains list of execution_ids to cleanup
|
||||
|
||||
Returns:
|
||||
Number of executions cleaned up and success message
|
||||
"""
|
||||
logger.info(
|
||||
f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions"
|
||||
)
|
||||
|
||||
cleaned_count = await cleanup_orphaned_executions_bulk(
|
||||
request.execution_ids, user.user_id
|
||||
)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=cleaned_count > 0,
|
||||
stopped_count=cleaned_count,
|
||||
message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SCHEDULE DIAGNOSTICS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SchedulesListResponse(BaseModel):
|
||||
"""Response model for list of schedules"""
|
||||
|
||||
schedules: List[ScheduleDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class OrphanedSchedulesListResponse(BaseModel):
|
||||
"""Response model for list of orphaned schedules"""
|
||||
|
||||
schedules: List[OrphanedScheduleDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class ScheduleCleanupRequest(BaseModel):
|
||||
"""Request model for cleaning up schedules"""
|
||||
|
||||
schedule_ids: List[str]
|
||||
|
||||
|
||||
class ScheduleCleanupResponse(BaseModel):
|
||||
"""Response model for schedule cleanup operations"""
|
||||
|
||||
success: bool
|
||||
deleted_count: int = 0
|
||||
message: str
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/schedules",
|
||||
response_model=ScheduleHealthMetrics,
|
||||
summary="Get Schedule Diagnostics",
|
||||
)
|
||||
async def get_schedule_diagnostics_endpoint():
|
||||
"""
|
||||
Get comprehensive diagnostic information about schedule health.
|
||||
|
||||
Returns schedule metrics including:
|
||||
- Total schedules (user vs system)
|
||||
- Orphaned schedules by category
|
||||
- Upcoming executions
|
||||
"""
|
||||
logger.info("Getting schedule diagnostics")
|
||||
|
||||
diagnostics = await get_schedule_health_metrics()
|
||||
|
||||
logger.info(
|
||||
f"Schedule diagnostics: total={diagnostics.total_schedules}, "
|
||||
f"user={diagnostics.user_schedules}, "
|
||||
f"orphaned={diagnostics.total_orphaned}"
|
||||
)
|
||||
|
||||
return diagnostics
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/schedules/all",
|
||||
response_model=SchedulesListResponse,
|
||||
summary="List All User Schedules",
|
||||
)
|
||||
async def list_all_schedules(
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Get detailed list of all user schedules (excludes system monitoring jobs).
|
||||
|
||||
Args:
|
||||
limit: Maximum number of schedules to return (default 100)
|
||||
offset: Number of schedules to skip (default 0)
|
||||
|
||||
Returns:
|
||||
List of schedules with details
|
||||
"""
|
||||
logger.info(f"Listing all schedules (limit={limit}, offset={offset})")
|
||||
|
||||
schedules = await get_all_schedules_details(limit=limit, offset=offset)
|
||||
|
||||
# Get total count
|
||||
diagnostics = await get_schedule_health_metrics()
|
||||
total = diagnostics.user_schedules
|
||||
|
||||
return SchedulesListResponse(schedules=schedules, total=total)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/diagnostics/schedules/orphaned",
|
||||
response_model=OrphanedSchedulesListResponse,
|
||||
summary="List Orphaned Schedules",
|
||||
)
|
||||
async def list_orphaned_schedules():
|
||||
"""
|
||||
Get detailed list of orphaned schedules with orphan reasons.
|
||||
|
||||
Returns:
|
||||
List of orphaned schedules categorized by orphan type
|
||||
"""
|
||||
logger.info("Listing orphaned schedules")
|
||||
|
||||
schedules = await get_orphaned_schedules_details()
|
||||
|
||||
return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/schedules/cleanup-orphaned",
|
||||
response_model=ScheduleCleanupResponse,
|
||||
summary="Cleanup Orphaned Schedules",
|
||||
)
|
||||
async def cleanup_orphaned_schedules(
|
||||
request: ScheduleCleanupRequest,
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup orphaned schedules by deleting from scheduler (admin only).
|
||||
|
||||
Args:
|
||||
request: Contains list of schedule_ids to delete
|
||||
|
||||
Returns:
|
||||
Number of schedules deleted and success message
|
||||
"""
|
||||
logger.info(
|
||||
f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules"
|
||||
)
|
||||
|
||||
deleted_count = await cleanup_orphaned_schedules_bulk(
|
||||
request.schedule_ids, user.user_id
|
||||
)
|
||||
|
||||
return ScheduleCleanupResponse(
|
||||
success=deleted_count > 0,
|
||||
deleted_count=deleted_count,
|
||||
message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/stop-all-long-running",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Stop ALL Long-Running Executions",
|
||||
)
|
||||
async def stop_all_long_running_executions_endpoint(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only).
|
||||
Operates on entire dataset, not limited to pagination.
|
||||
|
||||
Returns:
|
||||
Number of executions stopped and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} stopping ALL long-running executions")
|
||||
|
||||
stopped_count = await stop_all_long_running_executions(user.user_id)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=stopped_count > 0,
|
||||
stopped_count=stopped_count,
|
||||
message=f"Stopped {stopped_count} long-running executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/cleanup-all-orphaned",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Cleanup ALL Orphaned Executions",
|
||||
)
|
||||
async def cleanup_all_orphaned_executions(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup ALL orphaned executions (>24h old) by directly updating DB status.
|
||||
Operates on all executions, not just paginated results.
|
||||
|
||||
Returns:
|
||||
Number of executions cleaned up and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions")
|
||||
|
||||
# Fetch all orphaned execution IDs
|
||||
execution_ids = await get_all_orphaned_execution_ids()
|
||||
|
||||
if not execution_ids:
|
||||
return StopExecutionResponse(
|
||||
success=True,
|
||||
stopped_count=0,
|
||||
message="No orphaned executions to cleanup",
|
||||
)
|
||||
|
||||
cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=cleaned_count > 0,
|
||||
stopped_count=cleaned_count,
|
||||
message=f"Cleaned up {cleaned_count} orphaned executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/cleanup-all-stuck-queued",
|
||||
response_model=StopExecutionResponse,
|
||||
summary="Cleanup ALL Stuck Queued Executions",
|
||||
)
|
||||
async def cleanup_all_stuck_queued_executions_endpoint(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only).
|
||||
Operates on entire dataset, not limited to pagination.
|
||||
|
||||
Returns:
|
||||
Number of executions cleaned up and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions")
|
||||
|
||||
cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id)
|
||||
|
||||
return StopExecutionResponse(
|
||||
success=cleaned_count > 0,
|
||||
stopped_count=cleaned_count,
|
||||
message=f"Cleaned up {cleaned_count} stuck queued executions",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/diagnostics/executions/requeue-all-stuck",
|
||||
response_model=RequeueExecutionResponse,
|
||||
summary="Requeue ALL Stuck Queued Executions",
|
||||
)
|
||||
async def requeue_all_stuck_executions(
|
||||
user: AuthUser = Security(requires_admin_user),
|
||||
):
|
||||
"""
|
||||
Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ.
|
||||
Operates on all executions, not just paginated results.
|
||||
|
||||
Uses add_graph_execution with existing graph_exec_id to requeue.
|
||||
|
||||
⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits.
|
||||
|
||||
Returns:
|
||||
Number of executions requeued and success message
|
||||
"""
|
||||
logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions")
|
||||
|
||||
# Fetch all stuck queued execution IDs
|
||||
execution_ids = await get_all_stuck_queued_execution_ids()
|
||||
|
||||
if not execution_ids:
|
||||
return RequeueExecutionResponse(
|
||||
success=True,
|
||||
requeued_count=0,
|
||||
message="No stuck queued executions to requeue",
|
||||
)
|
||||
|
||||
# Get stuck executions by ID list (must be QUEUED)
|
||||
executions = await get_graph_executions(
|
||||
execution_ids=execution_ids,
|
||||
statuses=[AgentExecutionStatus.QUEUED],
|
||||
)
|
||||
|
||||
# Requeue all in parallel using add_graph_execution
|
||||
async def requeue_one(exec) -> bool:
|
||||
try:
|
||||
await add_graph_execution(
|
||||
graph_id=exec.graph_id,
|
||||
user_id=exec.user_id,
|
||||
graph_version=exec.graph_version,
|
||||
graph_exec_id=exec.id, # Requeue existing
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to requeue {exec.id}: {e}")
|
||||
return False
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[requeue_one(exec) for exec in executions], return_exceptions=False
|
||||
)
|
||||
|
||||
requeued_count = sum(1 for success in results if success)
|
||||
|
||||
return RequeueExecutionResponse(
|
||||
success=requeued_count > 0,
|
||||
requeued_count=requeued_count,
|
||||
message=f"Requeued {requeued_count} stuck executions",
|
||||
)
|
||||
@@ -1,889 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes
|
||||
from backend.data.diagnostics import (
|
||||
AgentDiagnosticsSummary,
|
||||
ExecutionDiagnosticsSummary,
|
||||
FailedExecutionDetail,
|
||||
OrphanedScheduleDetail,
|
||||
RunningExecutionDetail,
|
||||
ScheduleDetail,
|
||||
ScheduleHealthMetrics,
|
||||
)
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(diagnostics_admin_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_execution_diagnostics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test fetching execution diagnostics with invalid state detection"""
|
||||
mock_diagnostics = ExecutionDiagnosticsSummary(
|
||||
running_count=10,
|
||||
queued_db_count=5,
|
||||
rabbitmq_queue_depth=3,
|
||||
cancel_queue_depth=0,
|
||||
orphaned_running=2,
|
||||
orphaned_queued=1,
|
||||
failed_count_1h=5,
|
||||
failed_count_24h=20,
|
||||
failure_rate_24h=0.83,
|
||||
stuck_running_24h=1,
|
||||
stuck_running_1h=3,
|
||||
oldest_running_hours=26.5,
|
||||
stuck_queued_1h=2,
|
||||
queued_never_started=1,
|
||||
invalid_queued_with_start=1, # New invalid state
|
||||
invalid_running_without_start=1, # New invalid state
|
||||
completed_1h=50,
|
||||
completed_24h=1200,
|
||||
throughput_per_hour=50.0,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=mock_diagnostics,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify new invalid state fields are included
|
||||
assert data["invalid_queued_with_start"] == 1
|
||||
assert data["invalid_running_without_start"] == 1
|
||||
# Verify all expected fields present
|
||||
assert "running_executions" in data
|
||||
assert "orphaned_running" in data
|
||||
assert "failed_count_24h" in data
|
||||
|
||||
|
||||
def test_list_invalid_executions(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test listing executions in invalid states (read-only endpoint)"""
|
||||
mock_invalid_executions = [
|
||||
RunningExecutionDetail(
|
||||
execution_id="exec-invalid-1",
|
||||
graph_id="graph-123",
|
||||
graph_name="Test Graph",
|
||||
graph_version=1,
|
||||
user_id="user-123",
|
||||
user_email="test@example.com",
|
||||
status="QUEUED",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=datetime.now(
|
||||
timezone.utc
|
||||
), # QUEUED but has startedAt - INVALID!
|
||||
queue_status=None,
|
||||
),
|
||||
RunningExecutionDetail(
|
||||
execution_id="exec-invalid-2",
|
||||
graph_id="graph-456",
|
||||
graph_name="Another Graph",
|
||||
graph_version=2,
|
||||
user_id="user-456",
|
||||
user_email="user@example.com",
|
||||
status="RUNNING",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=None, # RUNNING but no startedAt - INVALID!
|
||||
queue_status=None,
|
||||
),
|
||||
]
|
||||
|
||||
mock_diagnostics = ExecutionDiagnosticsSummary(
|
||||
running_count=10,
|
||||
queued_db_count=5,
|
||||
rabbitmq_queue_depth=3,
|
||||
cancel_queue_depth=0,
|
||||
orphaned_running=0,
|
||||
orphaned_queued=0,
|
||||
failed_count_1h=0,
|
||||
failed_count_24h=0,
|
||||
failure_rate_24h=0.0,
|
||||
stuck_running_24h=0,
|
||||
stuck_running_1h=0,
|
||||
oldest_running_hours=None,
|
||||
stuck_queued_1h=0,
|
||||
queued_never_started=0,
|
||||
invalid_queued_with_start=1,
|
||||
invalid_running_without_start=1,
|
||||
completed_1h=0,
|
||||
completed_24h=0,
|
||||
throughput_per_hour=0.0,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details",
|
||||
return_value=mock_invalid_executions,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=mock_diagnostics,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 2 # Sum of both invalid state types
|
||||
assert len(data["executions"]) == 2
|
||||
# Verify both types of invalid states are returned
|
||||
assert data["executions"][0]["execution_id"] in [
|
||||
"exec-invalid-1",
|
||||
"exec-invalid-2",
|
||||
]
|
||||
assert data["executions"][1]["execution_id"] in [
|
||||
"exec-invalid-1",
|
||||
"exec-invalid-2",
|
||||
]
|
||||
|
||||
|
||||
def test_requeue_single_execution_with_add_graph_execution(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
admin_user_id: str,
|
||||
):
|
||||
"""Test requeueing uses add_graph_execution in requeue mode"""
|
||||
mock_exec_meta = GraphExecutionMeta(
|
||||
id="exec-stuck-123",
|
||||
user_id="user-123",
|
||||
graph_id="graph-456",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.QUEUED,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[mock_exec_meta],
|
||||
)
|
||||
|
||||
mock_add_graph_execution = mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue",
|
||||
json={"execution_id": "exec-stuck-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["requeued_count"] == 1
|
||||
|
||||
# Verify it used add_graph_execution in requeue mode
|
||||
mock_add_graph_execution.assert_called_once()
|
||||
call_kwargs = mock_add_graph_execution.call_args.kwargs
|
||||
assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode!
|
||||
assert call_kwargs["graph_id"] == "graph-456"
|
||||
assert call_kwargs["user_id"] == "user-123"
|
||||
|
||||
|
||||
def test_stop_single_execution_with_stop_graph_execution(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
admin_user_id: str,
|
||||
):
|
||||
"""Test stopping uses robust stop_graph_execution"""
|
||||
mock_exec_meta = GraphExecutionMeta(
|
||||
id="exec-running-123",
|
||||
user_id="user-789",
|
||||
graph_id="graph-999",
|
||||
graph_version=2,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[mock_exec_meta],
|
||||
)
|
||||
|
||||
mock_stop_graph_execution = mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop",
|
||||
json={"execution_id": "exec-running-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 1
|
||||
|
||||
# Verify it used stop_graph_execution with cascade
|
||||
mock_stop_graph_execution.assert_called_once()
|
||||
call_kwargs = mock_stop_graph_execution.call_args.kwargs
|
||||
assert call_kwargs["graph_exec_id"] == "exec-running-123"
|
||||
assert call_kwargs["user_id"] == "user-789"
|
||||
assert call_kwargs["cascade"] is True # Stops children too!
|
||||
assert call_kwargs["wait_timeout"] == 15.0
|
||||
|
||||
|
||||
def test_requeue_not_queued_execution_fails(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test that requeue fails if execution is not in QUEUED status"""
|
||||
# Mock an execution that's RUNNING (not QUEUED)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[], # No QUEUED executions found
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue",
|
||||
json={"execution_id": "exec-running-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found or not in QUEUED status" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_list_invalid_executions_no_bulk_actions(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Verify invalid executions endpoint is read-only (no bulk actions)"""
|
||||
# This is a documentation test - the endpoint exists but should not
|
||||
# have corresponding cleanup/stop/requeue endpoints
|
||||
|
||||
# These endpoints should NOT exist for invalid states:
|
||||
invalid_bulk_endpoints = [
|
||||
"/admin/diagnostics/executions/cleanup-invalid",
|
||||
"/admin/diagnostics/executions/stop-invalid",
|
||||
"/admin/diagnostics/executions/requeue-invalid",
|
||||
]
|
||||
|
||||
for endpoint in invalid_bulk_endpoints:
|
||||
response = client.post(endpoint, json={"execution_ids": ["test"]})
|
||||
assert response.status_code == 404, f"{endpoint} should not exist (read-only)"
|
||||
|
||||
|
||||
def test_execution_ids_filter_efficiency(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
):
|
||||
"""Test that bulk operations use efficient execution_ids filter"""
|
||||
mock_exec_metas = [
|
||||
GraphExecutionMeta(
|
||||
id=f"exec-{i}",
|
||||
user_id=f"user-{i}",
|
||||
graph_id="graph-123",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.QUEUED,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=datetime.now(timezone.utc),
|
||||
stats=None,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
mock_get_graph_executions = mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=mock_exec_metas,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue-bulk",
|
||||
json={"execution_ids": ["exec-0", "exec-1", "exec-2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify it used execution_ids filter (not fetching all queued)
|
||||
mock_get_graph_executions.assert_called_once()
|
||||
call_kwargs = mock_get_graph_executions.call_args.kwargs
|
||||
assert "execution_ids" in call_kwargs
|
||||
assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"]
|
||||
assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: reusable mock diagnostics summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary:
|
||||
defaults = dict(
|
||||
running_count=10,
|
||||
queued_db_count=5,
|
||||
rabbitmq_queue_depth=3,
|
||||
cancel_queue_depth=0,
|
||||
orphaned_running=2,
|
||||
orphaned_queued=1,
|
||||
failed_count_1h=5,
|
||||
failed_count_24h=20,
|
||||
failure_rate_24h=0.83,
|
||||
stuck_running_24h=3,
|
||||
stuck_running_1h=5,
|
||||
oldest_running_hours=26.5,
|
||||
stuck_queued_1h=2,
|
||||
queued_never_started=1,
|
||||
invalid_queued_with_start=1,
|
||||
invalid_running_without_start=1,
|
||||
completed_1h=50,
|
||||
completed_24h=1200,
|
||||
throughput_per_hour=50.0,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return ExecutionDiagnosticsSummary(**defaults)
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def _make_mock_execution(
|
||||
exec_id: str = "exec-1",
|
||||
status: str = "RUNNING",
|
||||
started_at: datetime | None | object = _SENTINEL,
|
||||
) -> RunningExecutionDetail:
|
||||
return RunningExecutionDetail(
|
||||
execution_id=exec_id,
|
||||
graph_id="graph-123",
|
||||
graph_name="Test Graph",
|
||||
graph_version=1,
|
||||
user_id="user-123",
|
||||
user_email="test@example.com",
|
||||
status=status,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=(
|
||||
datetime.now(timezone.utc) if started_at is _SENTINEL else started_at
|
||||
),
|
||||
queue_status=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_failed_execution(
|
||||
exec_id: str = "exec-fail-1",
|
||||
) -> FailedExecutionDetail:
|
||||
return FailedExecutionDetail(
|
||||
execution_id=exec_id,
|
||||
graph_id="graph-123",
|
||||
graph_name="Test Graph",
|
||||
graph_version=1,
|
||||
user_id="user-123",
|
||||
user_email="test@example.com",
|
||||
status="FAILED",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
failed_at=datetime.now(timezone.utc),
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics:
|
||||
defaults = dict(
|
||||
total_schedules=15,
|
||||
user_schedules=10,
|
||||
system_schedules=5,
|
||||
orphaned_deleted_graph=2,
|
||||
orphaned_no_library_access=1,
|
||||
orphaned_invalid_credentials=0,
|
||||
orphaned_validation_failed=0,
|
||||
total_orphaned=3,
|
||||
schedules_next_hour=4,
|
||||
schedules_next_24h=8,
|
||||
total_runs_next_hour=12,
|
||||
total_runs_next_24h=48,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return ScheduleHealthMetrics(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET endpoints: execution list variants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_running_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [
|
||||
_make_mock_execution("exec-run-1"),
|
||||
_make_mock_execution("exec-run-2"),
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 15 # running_count(10) + queued_db_count(5)
|
||||
assert len(data["executions"]) == 2
|
||||
assert data["executions"][0]["execution_id"] == "exec-run-1"
|
||||
|
||||
|
||||
def test_list_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3 # orphaned_running(2) + orphaned_queued(1)
|
||||
assert len(data["executions"]) == 1
|
||||
|
||||
|
||||
def test_list_failed_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [_make_mock_failed_execution("exec-fail-1")]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count",
|
||||
return_value=42,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 42
|
||||
assert len(data["executions"]) == 1
|
||||
assert data["executions"][0]["error_message"] == "Something went wrong"
|
||||
|
||||
|
||||
def test_list_long_running_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [_make_mock_execution("exec-long-1")]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/diagnostics/executions/long-running?limit=50&offset=0"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3 # stuck_running_24h
|
||||
assert len(data["executions"]) == 1
|
||||
|
||||
|
||||
def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_execs = [
|
||||
_make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None)
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details",
|
||||
return_value=mock_execs,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
|
||||
return_value=_make_mock_diagnostics(),
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/diagnostics/executions/stuck-queued?limit=50&offset=0"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 2 # stuck_queued_1h
|
||||
assert len(data["executions"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET endpoints: agent + schedule diagnostics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture):
|
||||
mock_diag = AgentDiagnosticsSummary(
|
||||
agents_with_active_executions=7,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics",
|
||||
return_value=mock_diag,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/agents")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["agents_with_active_executions"] == 7
|
||||
|
||||
|
||||
def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture):
|
||||
mock_metrics = _make_mock_schedule_health()
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
|
||||
return_value=mock_metrics,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/schedules")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_schedules"] == 10
|
||||
assert data["total_orphaned"] == 3
|
||||
assert data["total_runs_next_hour"] == 12
|
||||
|
||||
|
||||
def test_list_all_schedules(mocker: pytest_mock.MockFixture):
|
||||
mock_schedules = [
|
||||
ScheduleDetail(
|
||||
schedule_id="sched-1",
|
||||
schedule_name="Daily Run",
|
||||
graph_id="graph-1",
|
||||
graph_name="My Agent",
|
||||
graph_version=1,
|
||||
user_id="user-1",
|
||||
user_email="alice@example.com",
|
||||
cron="0 9 * * *",
|
||||
timezone="UTC",
|
||||
next_run_time=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details",
|
||||
return_value=mock_schedules,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
|
||||
return_value=_make_mock_schedule_health(),
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 10
|
||||
assert len(data["schedules"]) == 1
|
||||
assert data["schedules"][0]["schedule_name"] == "Daily Run"
|
||||
|
||||
|
||||
def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture):
|
||||
mock_orphans = [
|
||||
OrphanedScheduleDetail(
|
||||
schedule_id="sched-orphan-1",
|
||||
schedule_name="Ghost Schedule",
|
||||
graph_id="graph-deleted",
|
||||
graph_version=1,
|
||||
user_id="user-1",
|
||||
orphan_reason="deleted_graph",
|
||||
error_detail=None,
|
||||
next_run_time=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details",
|
||||
return_value=mock_orphans,
|
||||
)
|
||||
|
||||
response = client.get("/admin/diagnostics/schedules/orphaned")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 1
|
||||
assert data["schedules"][0]["orphan_reason"] == "deleted_graph"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST endpoints: bulk stop, cleanup, requeue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stop_multiple_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_exec_metas = [
|
||||
GraphExecutionMeta(
|
||||
id=f"exec-{i}",
|
||||
user_id=f"user-{i}",
|
||||
graph_id="graph-123",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.RUNNING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=None,
|
||||
stats=None,
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=mock_exec_metas,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop-bulk",
|
||||
json={"execution_ids": ["exec-0", "exec-1"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 2
|
||||
|
||||
|
||||
def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop-bulk",
|
||||
json={"execution_ids": ["nonexistent"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["stopped_count"] == 0
|
||||
|
||||
|
||||
def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
|
||||
return_value=3,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/cleanup-orphaned",
|
||||
json={"execution_ids": ["exec-1", "exec-2", "exec-3"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 3
|
||||
|
||||
|
||||
def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk",
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/schedules/cleanup-orphaned",
|
||||
json={"schedule_ids": ["sched-1", "sched-2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["deleted_count"] == 2
|
||||
|
||||
|
||||
def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions",
|
||||
return_value=5,
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/stop-all-long-running")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 5
|
||||
|
||||
|
||||
def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
|
||||
return_value=["exec-1", "exec-2"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 2
|
||||
|
||||
|
||||
def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 0
|
||||
assert "No orphaned" in data["message"]
|
||||
|
||||
|
||||
def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions",
|
||||
return_value=4,
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["stopped_count"] == 4
|
||||
|
||||
|
||||
def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture):
|
||||
mock_exec_metas = [
|
||||
GraphExecutionMeta(
|
||||
id=f"exec-stuck-{i}",
|
||||
user_id=f"user-{i}",
|
||||
graph_id="graph-123",
|
||||
graph_version=1,
|
||||
inputs=None,
|
||||
credential_inputs=None,
|
||||
nodes_input_masks=None,
|
||||
preset_id=None,
|
||||
status=AgentExecutionStatus.QUEUED,
|
||||
started_at=None,
|
||||
ended_at=None,
|
||||
stats=None,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
|
||||
return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=mock_exec_metas,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||
return_value=AsyncMock(),
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["requeued_count"] == 3
|
||||
|
||||
|
||||
def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["requeued_count"] == 0
|
||||
assert "No stuck" in data["message"]
|
||||
|
||||
|
||||
def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/requeue-bulk",
|
||||
json={"execution_ids": ["nonexistent"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["requeued_count"] == 0
|
||||
|
||||
|
||||
def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/diagnostics/executions/stop",
|
||||
json={"execution_id": "nonexistent"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
@@ -14,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class ExecutionDiagnosticsResponse(BaseModel):
|
||||
"""Response model for execution diagnostics"""
|
||||
|
||||
# Current execution state
|
||||
running_executions: int
|
||||
queued_executions_db: int
|
||||
queued_executions_rabbitmq: int
|
||||
cancel_queue_depth: int
|
||||
|
||||
# Orphaned execution detection
|
||||
orphaned_running: int
|
||||
orphaned_queued: int
|
||||
|
||||
# Failure metrics
|
||||
failed_count_1h: int
|
||||
failed_count_24h: int
|
||||
failure_rate_24h: float
|
||||
|
||||
# Long-running detection
|
||||
stuck_running_24h: int
|
||||
stuck_running_1h: int
|
||||
oldest_running_hours: float | None
|
||||
|
||||
# Stuck queued detection
|
||||
stuck_queued_1h: int
|
||||
queued_never_started: int
|
||||
|
||||
# Invalid state detection (data corruption - no auto-actions)
|
||||
invalid_queued_with_start: int
|
||||
invalid_running_without_start: int
|
||||
|
||||
# Throughput metrics
|
||||
completed_1h: int
|
||||
completed_24h: int
|
||||
throughput_per_hour: float
|
||||
|
||||
timestamp: str
|
||||
|
||||
|
||||
class AgentDiagnosticsResponse(BaseModel):
|
||||
"""Response model for agent diagnostics"""
|
||||
|
||||
agents_with_active_executions: int
|
||||
timestamp: str
|
||||
|
||||
|
||||
class ScheduleHealthMetrics(BaseModel):
|
||||
"""Response model for schedule diagnostics"""
|
||||
|
||||
total_schedules: int
|
||||
user_schedules: int
|
||||
system_schedules: int
|
||||
|
||||
# Orphan detection
|
||||
orphaned_deleted_graph: int
|
||||
orphaned_no_library_access: int
|
||||
orphaned_invalid_credentials: int
|
||||
orphaned_validation_failed: int
|
||||
total_orphaned: int
|
||||
|
||||
# Upcoming
|
||||
schedules_next_hour: int
|
||||
schedules_next_24h: int
|
||||
|
||||
timestamp: str
|
||||
|
||||
@@ -32,10 +32,10 @@ router = APIRouter(
|
||||
class UserRateLimitResponse(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
daily_cost_limit_microdollars: int
|
||||
weekly_cost_limit_microdollars: int
|
||||
daily_cost_used_microdollars: int
|
||||
weekly_cost_used_microdollars: int
|
||||
daily_token_limit: int
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
@@ -101,19 +101,17 @@ async def get_user_rate_limit(
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
resolved_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
user_email=resolved_email,
|
||||
daily_cost_limit_microdollars=daily_limit,
|
||||
weekly_cost_limit_microdollars=weekly_limit,
|
||||
daily_cost_used_microdollars=usage.daily.used,
|
||||
weekly_cost_used_microdollars=usage.weekly.used,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
@@ -143,9 +141,7 @@ async def reset_user_rate_limit(
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
@@ -158,10 +154,10 @@ async def reset_user_rate_limit(
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
user_email=resolved_email,
|
||||
daily_cost_limit_microdollars=daily_limit,
|
||||
weekly_cost_limit_microdollars=weekly_limit,
|
||||
daily_cost_used_microdollars=usage.daily.used,
|
||||
weekly_cost_used_microdollars=usage.weekly.used,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
@@ -85,11 +85,11 @@ def test_get_rate_limit(
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_cost_limit_microdollars"] == 2_500_000
|
||||
assert data["weekly_cost_limit_microdollars"] == 12_500_000
|
||||
assert data["daily_cost_used_microdollars"] == 500_000
|
||||
assert data["weekly_cost_used_microdollars"] == 3_000_000
|
||||
assert data["tier"] == "BASIC"
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
@@ -117,7 +117,7 @@ def test_get_rate_limit_by_email(
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_cost_limit_microdollars"] == 2_500_000
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email_not_found(
|
||||
@@ -160,10 +160,10 @@ def test_reset_user_usage_daily_only(
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_cost_used_microdollars"] == 0
|
||||
assert data["daily_tokens_used"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_cost_used_microdollars"] == 3_000_000
|
||||
assert data["tier"] == "BASIC"
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
@@ -192,9 +192,9 @@ def test_reset_user_usage_daily_and_weekly(
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_cost_used_microdollars"] == 0
|
||||
assert data["weekly_cost_used_microdollars"] == 0
|
||||
assert data["tier"] == "BASIC"
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
@@ -231,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
@@ -324,7 +324,7 @@ def test_set_user_tier(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.BASIC,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
@@ -347,7 +347,7 @@ def test_set_user_tier_downgrade(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test downgrading a user's tier from PRO to BASIC."""
|
||||
"""Test downgrading a user's tier from PRO to FREE."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
@@ -365,14 +365,14 @@ def test_set_user_tier_downgrade(
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "BASIC"},
|
||||
json={"user_id": target_user_id, "tier": "FREE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "BASIC"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.BASIC)
|
||||
assert data["tier"] == "FREE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier(
|
||||
@@ -456,7 +456,7 @@ def test_set_user_tier_db_failure(
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.BASIC,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
@@ -9,11 +10,11 @@ from uuid import uuid4
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.builder_context import resolve_session_permissions
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
@@ -25,18 +26,11 @@ from backend.copilot.model import (
|
||||
create_chat_session,
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_or_create_builder_session,
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.pending_message_helpers import (
|
||||
QueuePendingMessageResponse,
|
||||
is_turn_in_flight,
|
||||
queue_pending_for_http,
|
||||
)
|
||||
from backend.copilot.pending_messages import peek_pending_messages
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsagePublic,
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
acquire_reset_lock,
|
||||
check_rate_limit,
|
||||
@@ -47,14 +41,7 @@ from backend.copilot.rate_limit import (
|
||||
release_reset_lock,
|
||||
reset_daily_usage,
|
||||
)
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.service import strip_injected_context_for_display
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -82,14 +69,13 @@ from backend.copilot.tools.models import (
|
||||
NoResultsResponse,
|
||||
SetupRequirementsResponse,
|
||||
SuggestedGoalResponse,
|
||||
TodoWriteResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import build_files_block, resolve_workspace_files
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -99,6 +85,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -143,7 +133,7 @@ def _strip_injected_context(message: dict) -> dict:
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str = Field(max_length=64_000)
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
file_ids: list[str] | None = Field(
|
||||
@@ -161,53 +151,16 @@ class StreamChatRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class QueuePendingMessageRequest(BaseModel):
|
||||
"""Request model for queueing a follow-up while a turn is running."""
|
||||
|
||||
message: str = Field(max_length=64_000)
|
||||
context: dict[str, str] | None = None
|
||||
file_ids: list[str] | None = Field(default=None, max_length=20)
|
||||
|
||||
|
||||
class PeekPendingMessagesResponse(BaseModel):
|
||||
"""Response for the pending-message peek (GET) endpoint.
|
||||
|
||||
Returns a read-only view of the pending buffer — messages are NOT
|
||||
consumed. The frontend uses this to restore the queued-message
|
||||
indicator after a page refresh and to decide when to clear it once
|
||||
a turn has ended.
|
||||
"""
|
||||
|
||||
messages: list[str]
|
||||
count: int
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating (or get-or-creating) a chat session.
|
||||
|
||||
Two modes, selected by the body:
|
||||
|
||||
- Default: create a fresh session. ``dry_run`` is a **top-level**
|
||||
field — do not nest it inside ``metadata``.
|
||||
- Builder-bound: when ``builder_graph_id`` is set, the endpoint
|
||||
switches to **get-or-create** keyed on
|
||||
``(user_id, builder_graph_id)``. The builder panel calls this on
|
||||
mount so the chat persists across refreshes. Graph ownership is
|
||||
validated inside :func:`get_or_create_builder_session`. Write-side
|
||||
scope is enforced per-tool (``edit_agent`` / ``run_agent`` reject
|
||||
any ``agent_id`` other than the bound graph) and a small blacklist
|
||||
hides tools that conflict with the panel's scope
|
||||
(``create_agent`` / ``customize_agent`` / ``get_agent_building_guide``
|
||||
— see :data:`BUILDER_BLOCKED_TOOLS`). Read-side lookups
|
||||
(``find_block``, ``find_agent``, ``search_docs``, …) stay open.
|
||||
"""Request model for creating a new chat session.
|
||||
|
||||
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
|
||||
Extra/unknown fields are rejected (422) to prevent silent mis-use.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
dry_run: bool = False
|
||||
builder_graph_id: str | None = Field(default=None, max_length=128)
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -224,11 +177,6 @@ class ActiveStreamInfo(BaseModel):
|
||||
|
||||
turn_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
# ISO-8601 timestamp (UTC) marking when the backend registered the turn
|
||||
# as running. Lets the frontend seed its elapsed-time counter so restored
|
||||
# turns show honest "time since turn started" instead of the misleading
|
||||
# "time since this mount resumed the SSE".
|
||||
started_at: str | None = None
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
@@ -320,11 +268,8 @@ async def list_sessions(
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
for session in sessions:
|
||||
# Use the canonical helper so the hash-tag braces match every
|
||||
# other writer; building the key inline drops the braces and
|
||||
# silently misses every running session on cluster mode.
|
||||
pipe.hget(
|
||||
stream_registry.get_session_meta_key(session.session_id),
|
||||
f"{config.session_meta_prefix}{session.session_id}",
|
||||
"status",
|
||||
)
|
||||
statuses = await pipe.execute()
|
||||
@@ -360,43 +305,29 @@ async def create_session(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
request: CreateSessionRequest | None = None,
|
||||
) -> CreateSessionResponse:
|
||||
"""Create (or get-or-create) a chat session.
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Two modes, selected by the request body:
|
||||
|
||||
- Default: create a fresh session for the user. ``dry_run=True`` forces
|
||||
run_block and run_agent calls to use dry-run simulation.
|
||||
- Builder-bound: when ``builder_graph_id`` is set, get-or-create keyed
|
||||
on ``(user_id, builder_graph_id)``. Returns the existing session for
|
||||
that graph or creates one locked to it. Graph ownership is validated
|
||||
inside :func:`get_or_create_builder_session`; raises 404 on
|
||||
unauthorized access. Write-side scope is enforced per-tool
|
||||
(``edit_agent`` / ``run_agent`` reject any ``agent_id`` other than
|
||||
the bound graph) and a small blacklist hides tools that conflict
|
||||
with the panel's scope (see :data:`BUILDER_BLOCKED_TOOLS`).
|
||||
Initiates a new chat session for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
request: Optional request body with ``dry_run`` and/or
|
||||
``builder_graph_id``.
|
||||
request: Optional request body. When provided, ``dry_run=True``
|
||||
forces run_block and run_agent calls to use dry-run simulation.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the resulting session.
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
dry_run = request.dry_run if request else False
|
||||
builder_graph_id = request.builder_graph_id if request else None
|
||||
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
f"{', dry_run=True' if dry_run else ''}"
|
||||
f"{f', builder_graph_id={builder_graph_id}' if builder_graph_id else ''}"
|
||||
)
|
||||
|
||||
if builder_graph_id:
|
||||
session = await get_or_create_builder_session(user_id, builder_graph_id)
|
||||
else:
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
@@ -552,7 +483,6 @@ async def get_session(
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
started_at=active_session.created_at.isoformat(),
|
||||
)
|
||||
|
||||
# Skip session metadata on "load more" — frontend only needs messages
|
||||
@@ -593,27 +523,23 @@ async def get_session(
|
||||
)
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CoPilotUsagePublic:
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns the percentage of the daily/weekly allowance used — not the
|
||||
raw spend or cap — so clients cannot derive per-turn cost or platform
|
||||
margins. Global defaults sourced from LaunchDarkly (falling back to
|
||||
config). Includes the user's rate-limit tier.
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
Includes the user's rate-limit tier.
|
||||
"""
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
status = await get_usage_status(
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
return CoPilotUsagePublic.from_status(status)
|
||||
|
||||
|
||||
class RateLimitResetResponse(BaseModel):
|
||||
@@ -622,9 +548,7 @@ class RateLimitResetResponse(BaseModel):
|
||||
success: bool
|
||||
credits_charged: int = Field(description="Credits charged (in cents)")
|
||||
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
|
||||
usage: CoPilotUsagePublic = Field(
|
||||
description="Updated usage status after reset (percentages only)"
|
||||
)
|
||||
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -648,7 +572,7 @@ async def reset_copilot_usage(
|
||||
) -> RateLimitResetResponse:
|
||||
"""Reset the daily CoPilot rate limit by spending credits.
|
||||
|
||||
Allows users who have hit their daily cost limit to spend credits
|
||||
Allows users who have hit their daily token limit to spend credits
|
||||
to reset their daily usage counter and continue working.
|
||||
Returns 400 if the feature is disabled or the user is not over the limit.
|
||||
Returns 402 if the user has insufficient credits.
|
||||
@@ -667,9 +591,7 @@ async def reset_copilot_usage(
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
if daily_limit <= 0:
|
||||
@@ -706,8 +628,8 @@ async def reset_copilot_usage(
|
||||
# used for limit checks, not returned to the client.)
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
tier=tier,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
@@ -742,7 +664,7 @@ async def reset_copilot_usage(
|
||||
|
||||
# Reset daily usage in Redis. If this fails, refund the credits
|
||||
# so the user is not charged for a service they did not receive.
|
||||
if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit):
|
||||
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
|
||||
# Compensate: refund the charged credits.
|
||||
refunded = False
|
||||
try:
|
||||
@@ -778,11 +700,11 @@ async def reset_copilot_usage(
|
||||
finally:
|
||||
await release_reset_lock(user_id)
|
||||
|
||||
# Return updated usage status (public schema — percentages only).
|
||||
# Return updated usage status.
|
||||
updated_usage = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
@@ -791,7 +713,7 @@ async def reset_copilot_usage(
|
||||
success=True,
|
||||
credits_charged=cost,
|
||||
remaining_balance=remaining,
|
||||
usage=CoPilotUsagePublic.from_status(updated_usage),
|
||||
usage=updated_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -840,81 +762,38 @@ async def cancel_session_task(
|
||||
return CancelSessionResponse(cancelled=True)
|
||||
|
||||
|
||||
def _ui_message_stream_headers() -> dict[str, str]:
|
||||
return {
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
}
|
||||
|
||||
|
||||
def _empty_ui_message_stream_response() -> StreamingResponse:
|
||||
# Stable placeholder messageId for the empty queued-mid-turn stream.
|
||||
# Real turns generate per-message UUIDs via the executor; this stream
|
||||
# has no message to attach to, but the AI SDK parser still requires a
|
||||
# non-empty ``messageId`` field on ``StreamStart``.
|
||||
message_id = uuid4().hex
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
# Vercel AI SDK's UI-message-stream parser expects symmetric
|
||||
# start/finish framing at both stream and step level — every
|
||||
# non-empty turn emits the pair. Without an opener, today's parser
|
||||
# tolerates the closer (no active parts to flush) but a future SDK
|
||||
# tightening would silently break the queue-mid-turn UX. Emit the
|
||||
# full empty pair so the contract stays correct.
|
||||
yield StreamStart(messageId=message_id).to_sse()
|
||||
yield StreamStartStep().to_sse()
|
||||
yield StreamFinishStep().to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers=_ui_message_stream_headers(),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
responses={
|
||||
404: {"description": "Session not found or access denied"},
|
||||
429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
|
||||
},
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""Start a new turn and return an AI SDK UI message stream.
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Returns an SSE stream (``text/event-stream``) with Vercel AI SDK chunks
|
||||
(text fragments, tool-call UI, tool results). The generation runs in a
|
||||
background task that survives client disconnects; reconnect via
|
||||
``GET /sessions/{session_id}/stream`` to resume.
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
Follow-up messages typed while a turn is already running should use
|
||||
``POST /sessions/{session_id}/messages/pending``. If an older client still
|
||||
posts that follow-up here, we queue it defensively but still return a valid
|
||||
empty UI-message stream so AI SDK transports never receive a JSON body from
|
||||
the stream endpoint.
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to a per-turn Redis stream for reconnection support. If the client
|
||||
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier.
|
||||
request: Request body with message, is_user_message, and optional context.
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
# Wall-clock arrival time, propagated to the executor so the turn-start
|
||||
# drain can order pending messages relative to this request (pending
|
||||
# pushed BEFORE this instant were typed earlier; pending pushed AFTER
|
||||
# are race-path follow-ups typed while /stream was still processing).
|
||||
request_arrival_at = time.time()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
|
||||
|
||||
logger.info(
|
||||
@@ -922,31 +801,7 @@ async def stream_chat_post(
|
||||
f"user={user_id}, message_len={len(request.message)}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
if (
|
||||
request.is_user_message
|
||||
and request.message
|
||||
and await is_turn_in_flight(session_id)
|
||||
):
|
||||
try:
|
||||
await queue_pending_for_http(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
context=request.context,
|
||||
file_ids=request.file_ids,
|
||||
)
|
||||
return _empty_ui_message_stream_response()
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 409:
|
||||
raise
|
||||
|
||||
# Permission resolution is only needed below for the actual turn — keep
|
||||
# it after the queue-fall-through so a queued mid-turn request returns
|
||||
# without paying the work.
|
||||
builder_permissions = resolve_session_permissions(session)
|
||||
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
|
||||
extra={
|
||||
@@ -957,20 +812,18 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (cost-based, microdollars).
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit, _ = await get_global_rate_limits(
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
@@ -979,10 +832,33 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids:
|
||||
files = await resolve_workspace_files(user_id, request.file_ids)
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
request.message += build_files_block(files)
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
@@ -1041,8 +917,6 @@ async def stream_chat_post(
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
permissions=builder_permissions,
|
||||
request_arrival_at=request_arrival_at,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
@@ -1184,62 +1058,12 @@ async def stream_chat_post(
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers=_ui_message_stream_headers(),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/messages/pending",
|
||||
response_model=QueuePendingMessageResponse,
|
||||
responses={
|
||||
404: {"description": "Session not found or access denied"},
|
||||
409: {"description": "Session has no active turn to receive pending messages"},
|
||||
429: {"description": "Call-frequency cap exceeded"},
|
||||
},
|
||||
)
|
||||
async def queue_pending_message(
|
||||
session_id: str,
|
||||
request: QueuePendingMessageRequest,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""Queue a follow-up message while the session has an active turn."""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
if not await is_turn_in_flight(session_id):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Session has no active turn. Start a new turn with POST /stream.",
|
||||
)
|
||||
return await queue_pending_for_http(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
context=request.context,
|
||||
file_ids=request.file_ids,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/messages/pending",
|
||||
response_model=PeekPendingMessagesResponse,
|
||||
responses={
|
||||
404: {"description": "Session not found or access denied"},
|
||||
},
|
||||
)
|
||||
async def get_pending_messages(
|
||||
session_id: str,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""Peek at the pending-message buffer without consuming it.
|
||||
|
||||
Returns the current contents of the session's pending message buffer
|
||||
so the frontend can restore the queued-message indicator after a page
|
||||
refresh and clear it correctly once a turn drains the buffer.
|
||||
"""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
pending = await peek_pending_messages(session_id)
|
||||
return PeekPendingMessagesResponse(
|
||||
messages=[m.content for m in pending],
|
||||
count=len(pending),
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1248,7 +1072,6 @@ async def get_pending_messages(
|
||||
)
|
||||
async def resume_session_stream(
|
||||
session_id: str,
|
||||
last_chunk_id: str | None = Query(default=None, include_in_schema=False),
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
@@ -1258,26 +1081,27 @@ async def resume_session_stream(
|
||||
Checks for an active (in-progress) task on the session and either replays
|
||||
the full SSE stream or returns 204 No Content if nothing is running.
|
||||
|
||||
Always replays the active turn from ``0-0``. The AI SDK UI-message parser
|
||||
keeps text/reasoning part state inside a single parser instance; resuming
|
||||
from a Redis cursor can skip the ``*-start`` events required by later
|
||||
``*-delta`` chunks.
|
||||
Args:
|
||||
session_id: The chat session identifier.
|
||||
user_id: Optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
StreamingResponse (SSE) when an active stream exists,
|
||||
or 204 No Content when there is nothing to resume.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_session, _latest_backend_id = await stream_registry.get_active_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
if not active_session:
|
||||
return Response(status_code=204)
|
||||
|
||||
if last_chunk_id:
|
||||
logger.info(
|
||||
"Ignoring deprecated last_chunk_id on stream resume",
|
||||
extra={"session_id": session_id, "last_chunk_id": last_chunk_id},
|
||||
)
|
||||
|
||||
# Always replay from the beginning ("0-0") on resume.
|
||||
# We can't use last_message_id because it's the latest ID in the backend
|
||||
# stream, not the latest the frontend received — the gap causes lost
|
||||
# messages. The frontend deduplicates replayed content.
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
@@ -1338,7 +1162,12 @@ async def resume_session_stream(
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers=_ui_message_stream_headers(),
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1494,7 +1323,6 @@ ToolResponseUnion = (
|
||||
| MemorySearchResponse
|
||||
| MemoryForgetCandidatesResponse
|
||||
| MemoryForgetConfirmResponse
|
||||
| TodoWriteResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,6 @@ allowing frontend code generators like Orval to create corresponding TypeScript
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.model import CredentialsType
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
@@ -48,57 +47,6 @@ class ProviderNamesResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ProviderMetadata(BaseModel):
|
||||
"""Display metadata for a provider, shown in the settings integrations UI."""
|
||||
|
||||
name: str = Field(description="Provider slug (e.g. ``github``)")
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"One-line human-readable summary of what the provider does. "
|
||||
"Declared via ``ProviderBuilder.with_description(...)`` in the "
|
||||
"provider's ``_config.py``. ``None`` if not set."
|
||||
),
|
||||
)
|
||||
supported_auth_types: list[CredentialsType] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Credential types this provider accepts. Drives which connection "
|
||||
"tabs the settings UI renders for the provider. Empty list means "
|
||||
"no auth types declared."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_supported_auth_types(name: str) -> list[CredentialsType]:
|
||||
"""Return the provider's supported credential types from :class:`AutoRegistry`.
|
||||
|
||||
Populated by :meth:`ProviderBuilder.with_supported_auth_types` (or by
|
||||
``with_oauth`` / ``with_api_key`` / ``with_user_password`` when the provider
|
||||
uses the full builder chain). Returns an empty list for providers with no
|
||||
auth types declared.
|
||||
"""
|
||||
provider = AutoRegistry.get_provider(name)
|
||||
if provider is None:
|
||||
return []
|
||||
return sorted(provider.supported_auth_types)
|
||||
|
||||
|
||||
def get_provider_description(name: str) -> str | None:
|
||||
"""Return the provider's description from :class:`AutoRegistry`.
|
||||
|
||||
Descriptions are declared via ``ProviderBuilder.with_description(...)`` in
|
||||
the provider's ``_config.py`` (SDK path) or in
|
||||
``blocks/_static_provider_configs.py`` (for providers that don't yet have
|
||||
their own directory). Returns ``None`` for providers with no registered
|
||||
description.
|
||||
"""
|
||||
provider = AutoRegistry.get_provider(name)
|
||||
if provider is None:
|
||||
return None
|
||||
return provider.description
|
||||
|
||||
|
||||
class ProviderConstants(BaseModel):
|
||||
"""
|
||||
Model that exposes all provider names as a constant in the OpenAPI schema.
|
||||
|
||||
@@ -14,7 +14,7 @@ from fastapi import (
|
||||
Security,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
@@ -29,14 +29,15 @@ from backend.data.integrations import (
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import (
|
||||
@@ -47,14 +48,7 @@ from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.managed_credentials import (
|
||||
ensure_managed_credential,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvider
|
||||
from backend.integrations.managed_providers.ayrshare import (
|
||||
settings_available as ayrshare_settings_available,
|
||||
)
|
||||
from backend.integrations.managed_credentials import ensure_managed_credentials
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -66,14 +60,7 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import (
|
||||
ProviderConstants,
|
||||
ProviderMetadata,
|
||||
ProviderNamesResponse,
|
||||
get_all_provider_names,
|
||||
get_provider_description,
|
||||
get_supported_auth_types,
|
||||
)
|
||||
from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
@@ -100,23 +87,14 @@ async def login(
|
||||
scopes: Annotated[
|
||||
str, Query(title="Comma-separated list of authorization scopes")
|
||||
] = "",
|
||||
credential_id: Annotated[
|
||||
str | None,
|
||||
Query(title="ID of existing credential to upgrade scopes for"),
|
||||
] = None,
|
||||
) -> LoginResponse:
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
|
||||
requested_scopes = scopes.split(",") if scopes else []
|
||||
|
||||
if credential_id:
|
||||
requested_scopes = await _prepare_scope_upgrade(
|
||||
user_id, provider, credential_id, requested_scopes
|
||||
)
|
||||
|
||||
# Generate and store a secure random state token along with the scopes
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id, provider, requested_scopes, credential_id=credential_id
|
||||
user_id, provider, requested_scopes
|
||||
)
|
||||
login_url = handler.get_login_url(
|
||||
requested_scopes, state_token, code_challenge=code_challenge
|
||||
@@ -238,9 +216,7 @@ async def callback(
|
||||
)
|
||||
|
||||
# TODO: Allow specifying `title` to set on `credentials`
|
||||
credentials = await _merge_or_create_credential(
|
||||
user_id, provider, credentials, valid_state.credential_id
|
||||
)
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully processed OAuth callback for user {user_id} "
|
||||
@@ -250,38 +226,13 @@ async def callback(
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
# Bound the first-time sweep so a slow upstream (e.g. Ayrshare) can't hang
|
||||
# the credential-list endpoint. On timeout we still kick off a fire-and-
|
||||
# forget sweep so provisioning eventually completes; the user just won't
|
||||
# see the managed cred until the next refresh.
|
||||
_MANAGED_PROVISION_TIMEOUT_S = 10.0
|
||||
|
||||
|
||||
async def _ensure_managed_credentials_bounded(user_id: str) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
ensure_managed_credentials(user_id, creds_manager.store),
|
||||
timeout=_MANAGED_PROVISION_TIMEOUT_S,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Managed credential sweep exceeded %.1fs for user=%s; "
|
||||
"continuing without it — provisioning will complete in background",
|
||||
_MANAGED_PROVISION_TIMEOUT_S,
|
||||
user_id,
|
||||
)
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
|
||||
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
# Block on provisioning so managed credentials appear on the first load
|
||||
# instead of after a refresh, but with a timeout so a slow upstream
|
||||
# can't hang the endpoint. `_provisioned_users` short-circuits on
|
||||
# repeat calls.
|
||||
await _ensure_managed_credentials_bounded(user_id)
|
||||
# Fire-and-forget: provision missing managed credentials in the background.
|
||||
# The credential appears on the next page load; listing is never blocked.
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
@@ -296,7 +247,7 @@ async def list_credentials_by_provider(
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
await _ensure_managed_credentials_bounded(user_id)
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
@@ -330,115 +281,6 @@ async def get_credential(
|
||||
return to_meta_response(credential)
|
||||
|
||||
|
||||
class PickerTokenResponse(BaseModel):
|
||||
"""Short-lived OAuth access token shipped to the browser for rendering a
|
||||
provider-hosted picker UI (e.g. Google Drive Picker). Deliberately narrow:
|
||||
only the fields the client needs to initialize the picker widget. Issued
|
||||
from the user's own stored credential so ownership and scope gating are
|
||||
enforced by the credential lookup."""
|
||||
|
||||
access_token: str = Field(
|
||||
description="OAuth access token suitable for the picker SDK call."
|
||||
)
|
||||
access_token_expires_at: int | None = Field(
|
||||
default=None,
|
||||
description="Unix timestamp at which the access token expires, if known.",
|
||||
)
|
||||
|
||||
|
||||
# Allowlist of (provider, scopes) tuples that may mint picker tokens. Only
|
||||
# Drive-picker-capable scopes qualify so a caller can't use this endpoint to
|
||||
# extract a GitHub / other-provider OAuth token for unrelated purposes. If a
|
||||
# future provider integrates a hosted picker that needs a raw access token,
|
||||
# add its specific picker-relevant scopes here.
|
||||
_PICKER_TOKEN_ALLOWED_SCOPES: dict[ProviderName, frozenset[str]] = {
|
||||
ProviderName.GOOGLE: frozenset(
|
||||
[
|
||||
"https://www.googleapis.com/auth/drive.file",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive",
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider}/credentials/{cred_id}/picker-token",
|
||||
summary="Issue a short-lived access token for a provider-hosted picker",
|
||||
operation_id="postV1GetPickerToken",
|
||||
)
|
||||
async def get_picker_token(
|
||||
provider: Annotated[
|
||||
ProviderName, Path(title="The provider that owns the credentials")
|
||||
],
|
||||
cred_id: Annotated[
|
||||
str, Path(title="The ID of the OAuth2 credentials to mint a token from")
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> PickerTokenResponse:
|
||||
"""Return the raw access token for an OAuth2 credential so the frontend
|
||||
can initialize a provider-hosted picker (e.g. Google Drive Picker).
|
||||
|
||||
`GET /{provider}/credentials/{cred_id}` deliberately strips secrets (see
|
||||
`CredentialsMetaResponse` + `TestGetCredentialReturnsMetaOnly` in
|
||||
`router_test.py`). That hardening broke the Drive picker, which needs the
|
||||
raw access token to call `google.picker.Builder.setOAuthToken(...)`. This
|
||||
endpoint carves a narrow, explicit hole: the caller must own the
|
||||
credential, it must be OAuth2, and the endpoint returns only the access
|
||||
token + its expiry — nothing else about the credential. SDK-default
|
||||
credentials are excluded for the same reason as `get_credential`.
|
||||
"""
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if not provider_matches(credential.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if not isinstance(credential, OAuth2Credentials):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Picker tokens are only available for OAuth2 credentials",
|
||||
)
|
||||
if not credential.access_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential has no access token; reconnect the account",
|
||||
)
|
||||
|
||||
# Gate on provider+scope: only credentials that actually grant access to
|
||||
# a provider-hosted picker flow may mint a token through this endpoint.
|
||||
# Prevents using this path to extract bearer tokens for unrelated OAuth
|
||||
# integrations (e.g. GitHub) that happen to be stored under the same user.
|
||||
allowed_scopes = _PICKER_TOKEN_ALLOWED_SCOPES.get(provider)
|
||||
if not allowed_scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(f"Picker tokens are not available for provider '{provider.value}'"),
|
||||
)
|
||||
cred_scopes = set(credential.scopes or [])
|
||||
if cred_scopes.isdisjoint(allowed_scopes):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Credential does not grant any scope eligible for the picker. "
|
||||
"Reconnect with the appropriate scope."
|
||||
),
|
||||
)
|
||||
|
||||
return PickerTokenResponse(
|
||||
access_token=credential.access_token.get_secret_value(),
|
||||
access_token_expires_at=credential.access_token_expires_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
async def create_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
@@ -732,186 +574,6 @@ async def _execute_webhook_preset_trigger(
|
||||
# Continue processing - webhook should be resilient to individual failures
|
||||
|
||||
|
||||
# -------------------- INCREMENTAL AUTH HELPERS -------------------- #
|
||||
|
||||
|
||||
async def _prepare_scope_upgrade(
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credential_id: str,
|
||||
requested_scopes: list[str],
|
||||
) -> list[str]:
|
||||
"""Validate an existing credential for scope upgrade and compute scopes.
|
||||
|
||||
For providers without native incremental auth (e.g. GitHub), returns the
|
||||
union of existing + requested scopes. For providers that handle merging
|
||||
server-side (e.g. Google with ``include_granted_scopes``), returns the
|
||||
requested scopes unchanged.
|
||||
|
||||
Raises HTTPException on validation failure.
|
||||
"""
|
||||
# Platform-owned system credentials must never be upgraded — scope
|
||||
# changes here would leak across every user that shares them.
|
||||
if is_system_credential(credential_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="System credentials cannot be upgraded",
|
||||
)
|
||||
|
||||
existing = await creds_manager.store.get_creds_by_id(user_id, credential_id)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credential to upgrade not found",
|
||||
)
|
||||
if not isinstance(existing, OAuth2Credentials):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only OAuth2 credentials can be upgraded",
|
||||
)
|
||||
if not provider_matches(existing.provider, provider.value):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential provider does not match the requested provider",
|
||||
)
|
||||
if existing.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Managed credentials cannot be upgraded",
|
||||
)
|
||||
|
||||
# Google handles scope merging via include_granted_scopes; others need
|
||||
# the union of existing + new scopes in the login URL.
|
||||
if provider != ProviderName.GOOGLE:
|
||||
requested_scopes = list(set(requested_scopes) | set(existing.scopes))
|
||||
|
||||
return requested_scopes
|
||||
|
||||
|
||||
async def _merge_or_create_credential(
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: OAuth2Credentials,
|
||||
credential_id: str | None,
|
||||
) -> OAuth2Credentials:
|
||||
"""Either upgrade an existing credential or create a new one.
|
||||
|
||||
When *credential_id* is set (explicit upgrade), merges scopes and updates
|
||||
the existing credential. Otherwise, checks for an implicit merge (same
|
||||
provider + username) before falling back to creating a new credential.
|
||||
"""
|
||||
if credential_id:
|
||||
return await _upgrade_existing_credential(user_id, credential_id, credentials)
|
||||
|
||||
# Implicit merge: check for existing credential with same provider+username.
|
||||
# Skip managed/system credentials and require a non-None username on both
|
||||
# sides so we never accidentally merge unrelated credentials.
|
||||
if credentials.username is None:
|
||||
await creds_manager.create(user_id, credentials)
|
||||
return credentials
|
||||
|
||||
existing_creds = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
matching = next(
|
||||
(
|
||||
c
|
||||
for c in existing_creds
|
||||
if isinstance(c, OAuth2Credentials)
|
||||
and not c.is_managed
|
||||
and not is_system_credential(c.id)
|
||||
and c.username is not None
|
||||
and c.username == credentials.username
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching:
|
||||
# Only merge into the existing credential when the new token
|
||||
# already covers every scope we're about to advertise on it.
|
||||
# Without this guard we'd overwrite ``matching.access_token`` with
|
||||
# a narrower token while storing a wider ``scopes`` list — the
|
||||
# record would claim authorizations the token does not grant, and
|
||||
# blocks using the lost scopes would fail with opaque 401/403s
|
||||
# until the user hits re-auth. On a narrowing login, keep the
|
||||
# two credentials separate instead.
|
||||
if set(credentials.scopes).issuperset(set(matching.scopes)):
|
||||
return await _upgrade_existing_credential(user_id, matching.id, credentials)
|
||||
|
||||
await creds_manager.create(user_id, credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
async def _upgrade_existing_credential(
|
||||
user_id: str,
|
||||
existing_cred_id: str,
|
||||
new_credentials: OAuth2Credentials,
|
||||
) -> OAuth2Credentials:
|
||||
"""Merge scopes from *new_credentials* into an existing credential."""
|
||||
# Defense-in-depth: re-check system and provider invariants right before
|
||||
# the write. The login-time check in `_prepare_scope_upgrade` can go stale
|
||||
# by the time the callback runs, and the implicit-merge path bypasses
|
||||
# login-time validation entirely, so every write-path must enforce these
|
||||
# on its own.
|
||||
if is_system_credential(existing_cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="System credentials cannot be upgraded",
|
||||
)
|
||||
existing = await creds_manager.store.get_creds_by_id(user_id, existing_cred_id)
|
||||
if not existing or not isinstance(existing, OAuth2Credentials):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential to upgrade not found",
|
||||
)
|
||||
if existing.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Managed credentials cannot be upgraded",
|
||||
)
|
||||
if not provider_matches(existing.provider, new_credentials.provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential provider does not match the requested provider",
|
||||
)
|
||||
|
||||
if (
|
||||
existing.username
|
||||
and new_credentials.username
|
||||
and existing.username != new_credentials.username
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username mismatch: authenticated as a different user",
|
||||
)
|
||||
|
||||
# Operate on a copy so the caller's ``new_credentials`` object is not
|
||||
# mutated out from under them. Every caller today immediately discards
|
||||
# or replaces its reference, but the implicit-merge path in
|
||||
# ``_merge_or_create_credential`` reads ``credentials.scopes`` before
|
||||
# calling into us — a future reader after the call would otherwise
|
||||
# silently see the overwritten values.
|
||||
merged = new_credentials.model_copy(deep=True)
|
||||
merged.id = existing.id
|
||||
merged.title = existing.title
|
||||
merged.scopes = list(set(existing.scopes) | set(new_credentials.scopes))
|
||||
merged.metadata = {
|
||||
**(existing.metadata or {}),
|
||||
**(new_credentials.metadata or {}),
|
||||
}
|
||||
# Preserve the existing refresh_token and username if the incremental
|
||||
# response doesn't carry them. Providers like Google only return a
|
||||
# refresh_token on first authorization — dropping it here would orphan
|
||||
# the credential on the next access-token expiry, forcing the user to
|
||||
# re-auth from scratch. Username is similarly sticky: if we've already
|
||||
# resolved it for this credential, keep it rather than silently
|
||||
# blanking it on an incremental upgrade.
|
||||
if not merged.refresh_token and existing.refresh_token:
|
||||
merged.refresh_token = existing.refresh_token
|
||||
merged.refresh_token_expires_at = existing.refresh_token_expires_at
|
||||
if not merged.username and existing.username:
|
||||
merged.username = existing.username
|
||||
await creds_manager.update(user_id, merged)
|
||||
return merged
|
||||
|
||||
|
||||
# --------------------------- UTILITIES ---------------------------- #
|
||||
|
||||
|
||||
@@ -1122,21 +784,12 @@ def _get_provider_oauth_handler(
|
||||
async def get_ayrshare_sso_url(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> AyrshareSSOResponse:
|
||||
"""Generate a JWT SSO URL so the user can link their social accounts.
|
||||
|
||||
The per-user Ayrshare profile key is provisioned and persisted as a
|
||||
standard ``is_managed=True`` credential by
|
||||
:class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`.
|
||||
This endpoint only signs a short-lived JWT pointing at the Ayrshare-
|
||||
hosted social-linking page; all profile lifecycle logic lives with the
|
||||
managed provider.
|
||||
"""
|
||||
if not ayrshare_settings_available():
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Ayrshare integration is not configured",
|
||||
)
|
||||
Generate an SSO URL for Ayrshare social media integration.
|
||||
|
||||
Returns:
|
||||
dict: Contains the SSO URL for Ayrshare integration
|
||||
"""
|
||||
try:
|
||||
client = AyrshareClient()
|
||||
except MissingConfigError:
|
||||
@@ -1145,63 +798,66 @@ async def get_ayrshare_sso_url(
|
||||
detail="Ayrshare integration is not configured",
|
||||
)
|
||||
|
||||
# On-demand provisioning: AyrshareManagedProvider opts out of the
|
||||
# credentials sweep (profile quota is per-user subscription-bound). This
|
||||
# endpoint is the only trigger that provisions a profile — one Ayrshare
|
||||
# profile per user who actually opens the connect flow, not one per
|
||||
# every authenticated user.
|
||||
provisioned = await ensure_managed_credential(
|
||||
user_id, creds_manager.store, AyrshareManagedProvider()
|
||||
)
|
||||
if not provisioned:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to provision Ayrshare profile",
|
||||
)
|
||||
# Ayrshare profile key is stored in the credentials store
|
||||
# It is generated when creating a new profile, if there is no profile key,
|
||||
# we create a new profile and store the profile key in the credentials store
|
||||
|
||||
ayrshare_creds = [
|
||||
c
|
||||
for c in await creds_manager.store.get_creds_by_provider(user_id, "ayrshare")
|
||||
if c.is_managed and isinstance(c, APIKeyCredentials)
|
||||
]
|
||||
if not ayrshare_creds:
|
||||
logger.error(
|
||||
"Ayrshare credential provisioning did not produce a credential "
|
||||
"for user %s",
|
||||
user_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to provision Ayrshare profile",
|
||||
)
|
||||
profile_key_str = ayrshare_creds[0].api_key.get_secret_value()
|
||||
user_integrations: UserIntegrations = await get_user_integrations(user_id)
|
||||
profile_key = user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
if not profile_key:
|
||||
logger.debug(f"Creating new Ayrshare profile for user {user_id}")
|
||||
try:
|
||||
profile = await client.create_profile(
|
||||
title=f"User {user_id}", messaging_active=True
|
||||
)
|
||||
profile_key = profile.profileKey
|
||||
await creds_manager.store.set_ayrshare_profile_key(user_id, profile_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Ayrshare profile for user {user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to create Ayrshare profile",
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Using existing Ayrshare profile for user {user_id}")
|
||||
|
||||
profile_key_str = (
|
||||
profile_key.get_secret_value()
|
||||
if isinstance(profile_key, SecretStr)
|
||||
else str(profile_key)
|
||||
)
|
||||
|
||||
private_key = settings.secrets.ayrshare_jwt_key
|
||||
# Ayrshare JWT max lifetime is 2880 minutes (48 h).
|
||||
# Ayrshare JWT expiry is 2880 minutes (48 hours)
|
||||
max_expiry_minutes = 2880
|
||||
try:
|
||||
logger.debug(f"Generating Ayrshare JWT for user {user_id}")
|
||||
jwt_response = await client.generate_jwt(
|
||||
private_key=private_key,
|
||||
profile_key=profile_key_str,
|
||||
# `allowed_social` is the set of networks the Ayrshare-hosted
|
||||
# social-linking page will *offer* the user to connect. Blocks
|
||||
# exist for more platforms than are listed here; the list is
|
||||
# deliberately narrower so the rollout can verify each network
|
||||
# end-to-end before widening the user-visible surface. Keep
|
||||
# in sync with tested platforms — extend as each is verified
|
||||
# against the block + Ayrshare's network-specific quirks.
|
||||
allowed_social=[
|
||||
# NOTE: We are enabling platforms one at a time
|
||||
# to speed up the development process
|
||||
# SocialPlatform.FACEBOOK,
|
||||
SocialPlatform.TWITTER,
|
||||
SocialPlatform.LINKEDIN,
|
||||
SocialPlatform.INSTAGRAM,
|
||||
SocialPlatform.YOUTUBE,
|
||||
# SocialPlatform.REDDIT,
|
||||
# SocialPlatform.TELEGRAM,
|
||||
# SocialPlatform.GOOGLE_MY_BUSINESS,
|
||||
# SocialPlatform.PINTEREST,
|
||||
SocialPlatform.TIKTOK,
|
||||
# SocialPlatform.BLUESKY,
|
||||
# SocialPlatform.SNAPCHAT,
|
||||
# SocialPlatform.THREADS,
|
||||
],
|
||||
expires_in=max_expiry_minutes,
|
||||
verify=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Error generating Ayrshare JWT for user %s: %s", user_id, exc)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Ayrshare JWT for user {user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT"
|
||||
)
|
||||
@@ -1211,37 +867,20 @@ async def get_ayrshare_sso_url(
|
||||
|
||||
|
||||
# === PROVIDER DISCOVERY ENDPOINTS ===
|
||||
@router.get("/providers", response_model=List[ProviderMetadata])
|
||||
async def list_providers() -> List[ProviderMetadata]:
|
||||
@router.get("/providers", response_model=List[str])
|
||||
async def list_providers() -> List[str]:
|
||||
"""
|
||||
Get metadata for every available provider.
|
||||
Get a list of all available provider names.
|
||||
|
||||
Returns both statically defined providers (from ``ProviderName`` enum) and
|
||||
dynamically registered providers (from SDK decorators). Each entry includes
|
||||
a ``description`` declared via ``ProviderBuilder.with_description(...)`` in
|
||||
the provider's ``_config.py``.
|
||||
Returns both statically defined providers (from ProviderName enum)
|
||||
and dynamically registered providers (from SDK decorators).
|
||||
|
||||
Note: The complete list of provider names is also available as a constant
|
||||
in the generated TypeScript client via PROVIDER_NAMES.
|
||||
"""
|
||||
# Ensure all block modules (and therefore every provider's _config.py) are
|
||||
# imported before we read from AutoRegistry. Cached on first call.
|
||||
try:
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
load_all_blocks()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load blocks for provider metadata: {e}")
|
||||
|
||||
# Get all providers at runtime
|
||||
all_providers = get_all_provider_names()
|
||||
return [
|
||||
ProviderMetadata(
|
||||
name=name,
|
||||
description=get_provider_description(name),
|
||||
supported_auth_types=get_supported_auth_types(name),
|
||||
)
|
||||
for name in all_providers
|
||||
]
|
||||
return all_providers
|
||||
|
||||
|
||||
@router.get("/providers/system", response_model=List[str])
|
||||
|
||||
@@ -393,7 +393,7 @@ class TestEnsureManagedCredentials:
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_awaited_once_with("user-1", store)
|
||||
provider.provision.assert_awaited_once_with("user-1")
|
||||
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -568,181 +568,3 @@ class TestCleanupManagedCredentials:
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
# No exception raised — cleanup failure is swallowed.
|
||||
|
||||
|
||||
class TestGetPickerToken:
|
||||
"""POST /{provider}/credentials/{cred_id}/picker-token must:
|
||||
1. Return the access token for OAuth2 creds the caller owns.
|
||||
2. 404 for non-owned, non-existent, or wrong-provider creds.
|
||||
3. 400 for non-OAuth2 creds (API key, host-scoped, user/password).
|
||||
4. 404 for SDK default creds (same hardening as get_credential).
|
||||
5. Preserve the `TestGetCredentialReturnsMetaOnly` contract — the
|
||||
existing meta-only endpoint must still strip secrets even after
|
||||
this picker-token endpoint exists."""
|
||||
|
||||
def test_oauth2_owner_gets_access_token(self):
|
||||
# Use a Google cred with a drive.file scope — only picker-eligible
|
||||
# (provider, scope) pairs can mint a token. GitHub-style creds are
|
||||
# explicitly rejected; see `test_non_picker_provider_rejected_as_400`.
|
||||
cred = _make_oauth2_cred(
|
||||
cred_id="cred-gdrive",
|
||||
provider="google",
|
||||
)
|
||||
cred.scopes = ["https://www.googleapis.com/auth/drive.file"]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/google/credentials/cred-gdrive/picker-token")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# The whole point of this endpoint: the access token IS returned here.
|
||||
assert data["access_token"] == "ghp_secret_token"
|
||||
# Only the two declared fields come back — nothing else leaks.
|
||||
assert set(data.keys()) <= {"access_token", "access_token_expires_at"}
|
||||
|
||||
def test_non_picker_provider_rejected_as_400(self):
|
||||
"""Provider allowlist: even with a valid OAuth2 credential, a
|
||||
non-picker provider (GitHub, etc.) cannot mint a picker token.
|
||||
Stops this endpoint from being used as a generic bearer-token
|
||||
extraction path for any stored OAuth cred under the same user."""
|
||||
cred = _make_oauth2_cred(provider="github")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/github/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "not available for provider" in resp.json()["detail"]
|
||||
assert "ghp_secret_token" not in str(resp.json())
|
||||
|
||||
def test_google_oauth_without_drive_scope_rejected(self):
|
||||
"""Scope allowlist: a Google OAuth2 cred that only carries non-picker
|
||||
scopes (e.g. gmail.readonly, calendar) cannot mint a picker token.
|
||||
Forces the frontend to reconnect with a Drive scope before the
|
||||
picker is available."""
|
||||
cred = _make_oauth2_cred(provider="google")
|
||||
cred.scopes = [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/google/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "picker" in resp.json()["detail"].lower()
|
||||
|
||||
def test_api_key_credential_rejected_as_400(self):
|
||||
cred = _make_api_key_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/openai/credentials/cred-123/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
# API keys must not silently fall through to a 200 response of some
|
||||
# other shape — the client should see a clear shape rejection.
|
||||
body = str(resp.json())
|
||||
assert "sk-secret-key-value" not in body
|
||||
|
||||
def test_user_password_credential_rejected_as_400(self):
|
||||
cred = _make_user_password_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/openai/credentials/cred-789/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
body = str(resp.json())
|
||||
assert "s3cret-pass" not in body
|
||||
assert "admin" not in body
|
||||
|
||||
def test_host_scoped_credential_rejected_as_400(self):
|
||||
cred = _make_host_scoped_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/openai/credentials/cred-host/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "top-secret" not in str(resp.json())
|
||||
|
||||
def test_missing_credential_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=None)
|
||||
resp = client.post("/github/credentials/nonexistent/picker-token")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_wrong_provider_returns_404(self):
|
||||
"""Symmetric with get_credential: provider mismatch is a generic
|
||||
404, not a 400, so we don't leak existence of a credential the
|
||||
caller doesn't own on that provider."""
|
||||
cred = _make_oauth2_cred(provider="github")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/google/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_sdk_default_returns_404(self):
|
||||
"""SDK defaults are invisible to the user-facing API — picker-token
|
||||
must not mint a token for them either."""
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock()
|
||||
resp = client.post("/openai/credentials/openai-default/picker-token")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.get.assert_not_called()
|
||||
|
||||
def test_oauth2_without_access_token_returns_400(self):
|
||||
"""A stored OAuth2 cred whose access_token is missing can't satisfy
|
||||
a picker init. Surface a clear reconnect instruction rather than
|
||||
returning an empty string."""
|
||||
cred = _make_oauth2_cred()
|
||||
# Simulate a cred that lost its access token
|
||||
object.__setattr__(cred, "access_token", None)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.post("/github/credentials/cred-456/picker-token")
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "reconnect" in resp.json()["detail"].lower()
|
||||
|
||||
def test_meta_only_endpoint_still_strips_access_token(self):
|
||||
"""Regression guard for the coexistence contract: the new
|
||||
picker-token endpoint must NOT accidentally leak the token through
|
||||
the meta-only GET endpoint. TestGetCredentialReturnsMetaOnly
|
||||
covers this more broadly; this is a fast sanity check co-located
|
||||
with the new endpoint's tests."""
|
||||
cred = _make_oauth2_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-456")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "access_token" not in body
|
||||
assert "refresh_token" not in body
|
||||
assert "ghp_secret_token" not in str(body)
|
||||
|
||||
@@ -743,7 +743,6 @@ async def update_library_agent_version_and_settings(
|
||||
graph=agent_graph,
|
||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||
builder_chat_session_id=library.settings.builder_chat_session_id,
|
||||
)
|
||||
if updated_settings != library.settings:
|
||||
library = await update_library_agent(
|
||||
@@ -1804,7 +1803,7 @@ async def create_preset_from_graph_execution(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph_execution.graph_id} not found or accessible"
|
||||
)
|
||||
elif len(graph.regular_credentials_inputs) > 0:
|
||||
elif len(graph.aggregate_credentials_inputs()) > 0:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
||||
"because it was run before this feature existed "
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Platform bot linking — user-facing REST routes."""
|
||||
@@ -1,158 +0,0 @@
|
||||
"""User-facing platform_linking REST routes (JWT auth)."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Path, Security
|
||||
|
||||
from backend.data.db_accessors import platform_linking_db
|
||||
from backend.platform_linking.models import (
|
||||
ConfirmLinkResponse,
|
||||
ConfirmUserLinkResponse,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenInfoResponse,
|
||||
PlatformLinkInfo,
|
||||
PlatformUserLinkInfo,
|
||||
)
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
TokenPath = Annotated[
|
||||
str,
|
||||
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
|
||||
]
|
||||
|
||||
|
||||
def _translate(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, NotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
if isinstance(exc, NotAuthorizedError):
|
||||
return HTTPException(status_code=403, detail=str(exc))
|
||||
if isinstance(exc, LinkAlreadyExistsError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
if isinstance(exc, LinkTokenExpiredError):
|
||||
return HTTPException(status_code=410, detail=str(exc))
|
||||
if isinstance(exc, LinkFlowMismatchError):
|
||||
return HTTPException(status_code=400, detail=str(exc))
|
||||
return HTTPException(status_code=500, detail="Internal error.")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tokens/{token}/info",
|
||||
response_model=LinkTokenInfoResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Get display info for a link token",
|
||||
)
|
||||
async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
|
||||
try:
|
||||
return await platform_linking_db().get_link_token_info(token)
|
||||
except (NotFoundError, LinkTokenExpiredError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens/{token}/confirm",
|
||||
response_model=ConfirmLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a SERVER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_server_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user-tokens/{token}/confirm",
|
||||
response_model=ConfirmUserLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a USER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_user_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmUserLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_user_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/links",
|
||||
response_model=list[PlatformLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all platform servers linked to the authenticated user",
|
||||
)
|
||||
async def list_my_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformLinkInfo]:
|
||||
return await platform_linking_db().list_server_links(user_id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user-links",
|
||||
response_model=list[PlatformUserLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all DM links for the authenticated user",
|
||||
)
|
||||
async def list_my_user_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformUserLinkInfo]:
|
||||
return await platform_linking_db().list_user_links(user_id)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a platform server",
|
||||
)
|
||||
async def delete_link(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_server_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/user-links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a DM / user link",
|
||||
)
|
||||
async def delete_user_link_route(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_user_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
@@ -1,264 +0,0 @@
|
||||
"""Route tests: domain exceptions → HTTPException status codes."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
|
||||
|
||||
def _db_mock(**method_configs):
|
||||
"""Return a mock of the accessor's return value with the given AsyncMocks."""
|
||||
db = MagicMock()
|
||||
for name, mock in method_configs.items():
|
||||
setattr(db, name, mock)
|
||||
return db
|
||||
|
||||
|
||||
class TestTokenInfoRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import (
|
||||
get_link_token_info_route,
|
||||
)
|
||||
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await get_link_token_info_route(token="abc")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_maps_to_410(self):
|
||||
from backend.api.features.platform_linking.routes import (
|
||||
get_link_token_info_route,
|
||||
)
|
||||
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await get_link_token_info_route(token="abc")
|
||||
assert exc.value.status_code == 410
|
||||
|
||||
|
||||
class TestConfirmLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"exc,expected_status",
|
||||
[
|
||||
(NotFoundError("missing"), 404),
|
||||
(LinkFlowMismatchError("wrong flow"), 400),
|
||||
(LinkTokenExpiredError("expired"), 410),
|
||||
(LinkAlreadyExistsError("already"), 409),
|
||||
],
|
||||
)
|
||||
async def test_translation(self, exc: Exception, expected_status: int):
|
||||
from backend.api.features.platform_linking.routes import confirm_link_token
|
||||
|
||||
db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as ctx:
|
||||
await confirm_link_token(token="abc", user_id="u1")
|
||||
assert ctx.value.status_code == expected_status
|
||||
|
||||
|
||||
class TestConfirmUserLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"exc,expected_status",
|
||||
[
|
||||
(NotFoundError("missing"), 404),
|
||||
(LinkFlowMismatchError("wrong flow"), 400),
|
||||
(LinkTokenExpiredError("expired"), 410),
|
||||
(LinkAlreadyExistsError("already"), 409),
|
||||
],
|
||||
)
|
||||
async def test_translation(self, exc: Exception, expected_status: int):
|
||||
from backend.api.features.platform_linking.routes import confirm_user_link_token
|
||||
|
||||
db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as ctx:
|
||||
await confirm_user_link_token(token="abc", user_id="u1")
|
||||
assert ctx.value.status_code == expected_status
|
||||
|
||||
|
||||
class TestDeleteLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import delete_link
|
||||
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_link(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_owned_maps_to_403(self):
|
||||
from backend.api.features.platform_linking.routes import delete_link
|
||||
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_link(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
class TestDeleteUserLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import delete_user_link_route
|
||||
|
||||
db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_user_link_route(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_owned_maps_to_403(self):
|
||||
from backend.api.features.platform_linking.routes import delete_user_link_route
|
||||
|
||||
db = _db_mock(
|
||||
delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_user_link_route(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
# ── Adversarial: malformed token path params ──────────────────────────
|
||||
|
||||
|
||||
class TestAdversarialTokenPath:
|
||||
# TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.features.platform_linking.routes as routes_mod
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.dependency_overrides[requires_user] = lambda: None
|
||||
app.dependency_overrides[get_user_id] = lambda: "caller-user"
|
||||
app.include_router(routes_mod.router, prefix="/api/platform-linking")
|
||||
return TestClient(app)
|
||||
|
||||
def test_rejects_token_with_special_chars(self, client):
|
||||
response = client.get("/api/platform-linking/tokens/bad%24token/info")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_rejects_token_with_path_traversal(self, client):
|
||||
for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
|
||||
response = client.get(f"/api/platform-linking/tokens/{probe}/info")
|
||||
assert response.status_code in (
|
||||
404,
|
||||
422,
|
||||
), f"path-traversal probe {probe!r} returned {response.status_code}"
|
||||
|
||||
def test_rejects_token_too_long(self, client):
|
||||
long_token = "a" * 65
|
||||
response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_accepts_token_at_max_length(self, client):
|
||||
token = "a" * 64
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
response = client.get(f"/api/platform-linking/tokens/{token}/info")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_accepts_urlsafe_b64_token_shape(self, client):
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_confirm_rejects_malformed_token(self, client):
|
||||
response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestAdversarialDeleteLinkId:
|
||||
"""DELETE link_id has no regex — ensure weird values are handled via
|
||||
NotFoundError (no crash, no cross-user leak)."""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.features.platform_linking.routes as routes_mod
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.dependency_overrides[requires_user] = lambda: None
|
||||
app.dependency_overrides[get_user_id] = lambda: "caller-user"
|
||||
app.include_router(routes_mod.router, prefix="/api/platform-linking")
|
||||
return TestClient(app)
|
||||
|
||||
def test_weird_link_id_returns_404(self, client):
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
|
||||
response = client.delete(f"/api/platform-linking/links/{link_id}")
|
||||
assert response.status_code in (404, 405)
|
||||
@@ -1,20 +0,0 @@
|
||||
import pydantic
|
||||
|
||||
|
||||
class PushSubscriptionKeys(pydantic.BaseModel):
|
||||
p256dh: str = pydantic.Field(min_length=1, max_length=512)
|
||||
auth: str = pydantic.Field(min_length=1, max_length=512)
|
||||
|
||||
|
||||
class PushSubscribeRequest(pydantic.BaseModel):
|
||||
endpoint: str = pydantic.Field(min_length=1, max_length=2048)
|
||||
keys: PushSubscriptionKeys
|
||||
user_agent: str | None = pydantic.Field(default=None, max_length=512)
|
||||
|
||||
|
||||
class PushUnsubscribeRequest(pydantic.BaseModel):
|
||||
endpoint: str = pydantic.Field(min_length=1, max_length=2048)
|
||||
|
||||
|
||||
class VapidPublicKeyResponse(pydantic.BaseModel):
|
||||
public_key: str
|
||||
@@ -1,64 +0,0 @@
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
from backend.api.features.push.model import (
|
||||
PushSubscribeRequest,
|
||||
PushUnsubscribeRequest,
|
||||
VapidPublicKeyResponse,
|
||||
)
|
||||
from backend.data.push_subscription import (
|
||||
delete_push_subscription,
|
||||
upsert_push_subscription,
|
||||
validate_push_endpoint,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
router = APIRouter()
|
||||
_settings = Settings()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/vapid-key",
|
||||
summary="Get VAPID public key for push subscription",
|
||||
)
|
||||
async def get_vapid_public_key() -> VapidPublicKeyResponse:
|
||||
return VapidPublicKeyResponse(public_key=_settings.secrets.vapid_public_key)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/subscribe",
|
||||
summary="Register a push subscription for the current user",
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def subscribe_push(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
body: PushSubscribeRequest,
|
||||
) -> None:
|
||||
try:
|
||||
await validate_push_endpoint(body.endpoint)
|
||||
await upsert_push_subscription(
|
||||
user_id=user_id,
|
||||
endpoint=body.endpoint,
|
||||
p256dh=body.keys.p256dh,
|
||||
auth=body.keys.auth,
|
||||
user_agent=body.user_agent,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/unsubscribe",
|
||||
summary="Remove a push subscription",
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def unsubscribe_push(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
body: PushUnsubscribeRequest,
|
||||
) -> None:
|
||||
await delete_push_subscription(user_id, body.endpoint)
|
||||
@@ -1,240 +0,0 @@
|
||||
"""Tests for push notification routes."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.api.features.push.routes import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_vapid_public_key(mocker):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.vapid_public_key = "test-vapid-public-key-base64url"
|
||||
mocker.patch(
|
||||
"backend.api.features.push.routes._settings",
|
||||
mock_settings,
|
||||
)
|
||||
|
||||
response = client.get("/vapid-key")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["public_key"] == "test-vapid-public-key-base64url"
|
||||
|
||||
|
||||
def test_get_vapid_public_key_empty(mocker):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.vapid_public_key = ""
|
||||
mocker.patch(
|
||||
"backend.api.features.push.routes._settings",
|
||||
mock_settings,
|
||||
)
|
||||
|
||||
response = client.get("/vapid-key")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["public_key"] == ""
|
||||
|
||||
|
||||
def test_subscribe_push(mocker, test_user_id):
|
||||
mock_upsert = mocker.patch(
|
||||
"backend.api.features.push.routes.upsert_push_subscription",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
|
||||
"keys": {
|
||||
"p256dh": "test-p256dh-key",
|
||||
"auth": "test-auth-key",
|
||||
},
|
||||
"user_agent": "Mozilla/5.0 Test",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
mock_upsert.assert_awaited_once_with(
|
||||
user_id=test_user_id,
|
||||
endpoint="https://fcm.googleapis.com/fcm/send/abc123",
|
||||
p256dh="test-p256dh-key",
|
||||
auth="test-auth-key",
|
||||
user_agent="Mozilla/5.0 Test",
|
||||
)
|
||||
|
||||
|
||||
def test_subscribe_push_without_user_agent(mocker, test_user_id):
|
||||
mock_upsert = mocker.patch(
|
||||
"backend.api.features.push.routes.upsert_push_subscription",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
|
||||
"keys": {
|
||||
"p256dh": "test-p256dh-key",
|
||||
"auth": "test-auth-key",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
mock_upsert.assert_awaited_once_with(
|
||||
user_id=test_user_id,
|
||||
endpoint="https://fcm.googleapis.com/fcm/send/abc123",
|
||||
p256dh="test-p256dh-key",
|
||||
auth="test-auth-key",
|
||||
user_agent=None,
|
||||
)
|
||||
|
||||
|
||||
def test_subscribe_push_missing_keys():
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_subscribe_push_missing_endpoint():
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"keys": {
|
||||
"p256dh": "test-p256dh-key",
|
||||
"auth": "test-auth-key",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_subscribe_push_rejects_empty_crypto_keys():
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
|
||||
"keys": {"p256dh": "", "auth": ""},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_subscribe_push_rejects_oversized_endpoint():
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/" + "x" * 3000,
|
||||
"keys": {"p256dh": "k", "auth": "a"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_unsubscribe_push(mocker, test_user_id):
|
||||
mock_delete = mocker.patch(
|
||||
"backend.api.features.push.routes.delete_push_subscription",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/unsubscribe",
|
||||
json={
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
mock_delete.assert_awaited_once_with(
|
||||
test_user_id,
|
||||
"https://fcm.googleapis.com/fcm/send/abc123",
|
||||
)
|
||||
|
||||
|
||||
def test_unsubscribe_push_missing_endpoint():
|
||||
response = client.post(
|
||||
"/unsubscribe",
|
||||
json={},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"untrusted_endpoint",
|
||||
[
|
||||
"https://localhost/evil",
|
||||
"https://127.0.0.1/evil",
|
||||
"https://169.254.169.254/latest/meta-data/",
|
||||
"https://internal-service.local/api",
|
||||
"https://attacker.example.com/push",
|
||||
"http://fcm.googleapis.com/fcm/send/abc",
|
||||
"file:///etc/passwd",
|
||||
],
|
||||
)
|
||||
def test_subscribe_push_rejects_untrusted_endpoints(mocker, untrusted_endpoint):
|
||||
mock_upsert = mocker.patch(
|
||||
"backend.api.features.push.routes.upsert_push_subscription",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"endpoint": untrusted_endpoint,
|
||||
"keys": {
|
||||
"p256dh": "test-p256dh-key",
|
||||
"auth": "test-auth-key",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
mock_upsert.assert_not_awaited()
|
||||
|
||||
|
||||
def test_subscribe_push_surfaces_cap_as_400(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.push.routes.upsert_push_subscription",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("Subscription limit of 20 per user reached"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/subscribe",
|
||||
json={
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
|
||||
"keys": {
|
||||
"p256dh": "test-p256dh-key",
|
||||
"auth": "test-auth-key",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Subscription limit" in response.json()["detail"]
|
||||
@@ -490,9 +490,6 @@ async def get_store_creators(
|
||||
# Build where clause with sanitized inputs
|
||||
where = {}
|
||||
|
||||
# Only return creators with approved agents
|
||||
where["num_agents"] = {"gt": 0}
|
||||
|
||||
if featured:
|
||||
where["is_featured"] = featured
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import prisma.enums
|
||||
import prisma.errors
|
||||
@@ -51,8 +50,8 @@ async def test_get_store_agents(mocker):
|
||||
|
||||
# Mock prisma calls
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_many = AsyncMock(return_value=mock_agents)
|
||||
mock_store_agent.return_value.count = AsyncMock(return_value=1)
|
||||
mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
|
||||
mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agents()
|
||||
@@ -95,7 +94,7 @@ async def test_get_store_agent_details(mocker):
|
||||
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = AsyncMock(return_value=mock_agent)
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
@@ -134,7 +133,7 @@ async def test_get_store_creator(mocker):
|
||||
|
||||
# Mock prisma call
|
||||
mock_creator = mocker.patch("prisma.models.Creator.prisma")
|
||||
mock_creator.return_value.find_unique = AsyncMock()
|
||||
mock_creator.return_value.find_unique = mocker.AsyncMock()
|
||||
# Configure the mock to return values that will pass validation
|
||||
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
||||
|
||||
@@ -190,7 +189,7 @@ async def test_create_store_submission(mocker):
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
subscriptionTier=prisma.enums.SubscriptionTier.BASIC, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
|
||||
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
@@ -237,23 +236,23 @@ async def test_create_store_submission(mocker):
|
||||
|
||||
# Mock prisma calls
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_first = AsyncMock(return_value=mock_agent)
|
||||
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock transaction context manager
|
||||
mock_tx = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.store.db.transaction",
|
||||
return_value=AsyncMock(
|
||||
__aenter__=AsyncMock(return_value=mock_tx),
|
||||
__aexit__=AsyncMock(return_value=False),
|
||||
return_value=mocker.AsyncMock(
|
||||
__aenter__=mocker.AsyncMock(return_value=mock_tx),
|
||||
__aexit__=mocker.AsyncMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_sl.return_value.find_unique = AsyncMock(return_value=None)
|
||||
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
|
||||
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
|
||||
mock_slv.return_value.create = AsyncMock(return_value=mock_version)
|
||||
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
|
||||
|
||||
# Call function
|
||||
result = await db.create_store_submission(
|
||||
@@ -293,8 +292,10 @@ async def test_update_profile(mocker):
|
||||
|
||||
# Mock prisma calls
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
|
||||
mock_profile_db.return_value.update = AsyncMock(return_value=mock_profile)
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)
|
||||
|
||||
# Test data
|
||||
profile = Profile(
|
||||
@@ -335,7 +336,9 @@ async def test_get_user_profile(mocker):
|
||||
|
||||
# Mock prisma calls
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.get_user_profile("user-id")
|
||||
@@ -393,38 +396,3 @@ async def test_get_store_agents_search_category_array_injection():
|
||||
# Verify the query executed without error
|
||||
# Category should be parameterized, preventing SQL injection
|
||||
assert isinstance(result.agents, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creators_only_returns_approved(mocker):
|
||||
mock_creators = [
|
||||
prisma.models.Creator(
|
||||
name="Creator One",
|
||||
username="creator1",
|
||||
description="desc",
|
||||
links=["link1"],
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=1,
|
||||
agent_rating=4.5,
|
||||
agent_runs=10,
|
||||
top_categories=["test"],
|
||||
is_featured=False,
|
||||
)
|
||||
]
|
||||
|
||||
mock_creator = mocker.patch("prisma.models.Creator.prisma")
|
||||
mock_creator.return_value.find_many = AsyncMock(return_value=mock_creators)
|
||||
mock_creator.return_value.count = AsyncMock(return_value=1)
|
||||
|
||||
result = await db.get_store_creators()
|
||||
|
||||
assert len(result.creators) == 1
|
||||
assert result.creators[0].username == "creator1"
|
||||
|
||||
mock_creator.return_value.find_many.assert_called_once()
|
||||
mock_creator.return_value.count.assert_called_once()
|
||||
|
||||
_, find_kwargs = mock_creator.return_value.find_many.call_args
|
||||
_, count_kwargs = mock_creator.return_value.count.call_args
|
||||
assert find_kwargs["where"]["num_agents"] == {"gt": 0}
|
||||
assert count_kwargs["where"]["num_agents"] == {"gt": 0}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,11 +26,10 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import SubscriptionTier
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
from backend.api.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
@@ -44,33 +43,26 @@ from backend.api.model import (
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.blocks import get_block, get_blocks
|
||||
from backend.copilot.rate_limit import get_tier_multipliers
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.auth import api_key as api_key_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
PendingChangeUnknown,
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
UserCredit,
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_active_subscription_period_end,
|
||||
get_auto_top_up,
|
||||
get_pending_subscription_change,
|
||||
get_proration_credit_cents,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
handle_subscription_payment_success,
|
||||
modify_stripe_subscription_for_tier,
|
||||
release_pending_subscription_schedule,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
sync_subscription_schedule_from_stripe,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
@@ -100,7 +92,6 @@ from backend.data.user import (
|
||||
update_user_notification_preference,
|
||||
update_user_timezone,
|
||||
)
|
||||
from backend.data.workspace import get_workspace_file_by_id
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
@@ -702,51 +693,20 @@ async def get_user_auto_top_up(
|
||||
|
||||
|
||||
class SubscriptionTierRequest(BaseModel):
|
||||
tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]
|
||||
tier: Literal["FREE", "PRO", "BUSINESS"]
|
||||
success_url: str = ""
|
||||
cancel_url: str = ""
|
||||
|
||||
|
||||
class SubscriptionCheckoutResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
|
||||
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
|
||||
monthly_cost: int # amount in cents (Stripe convention)
|
||||
tier_costs: dict[str, int] # tier name -> amount in cents
|
||||
tier_multipliers: dict[str, float] = Field(
|
||||
default_factory=dict,
|
||||
description=(
|
||||
"Tier → rate-limit multiplier. Covers the same tiers listed in"
|
||||
" ``tier_costs`` so the frontend can render rate-limit badges"
|
||||
" relative to the lowest visible tier without knowing backend"
|
||||
" defaults."
|
||||
),
|
||||
)
|
||||
proration_credit_cents: int # unused portion of current sub to convert on upgrade
|
||||
has_active_stripe_subscription: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"True when the user has an active/trialing Stripe subscription. The"
|
||||
" frontend uses this to branch upgrade UX: modify-in-place + saved-card"
|
||||
" auto-charge when True, redirect to Stripe Checkout when False."
|
||||
),
|
||||
)
|
||||
current_period_end: Optional[int] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Unix timestamp of the active subscription's current_period_end. Used"
|
||||
" to show the date Stripe will issue the next invoice (with prorated"
|
||||
" upgrade charges, if any). None when no active sub."
|
||||
),
|
||||
)
|
||||
pending_tier: Optional[Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]] = None
|
||||
pending_tier_effective_at: Optional[datetime] = None
|
||||
url: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"Populated only when POST /credits/subscription starts a Stripe Checkout"
|
||||
" Session (BASIC → paid upgrade). Empty string in all other branches —"
|
||||
" the client redirects to this URL when non-empty."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
@@ -822,89 +782,39 @@ async def get_subscription_status(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionStatusResponse:
|
||||
user = await get_user_by_id(user_id)
|
||||
tier = user.subscription_tier or SubscriptionTier.NO_TIER
|
||||
tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
|
||||
# Tiers that *can* have a Stripe price configured (and therefore appear
|
||||
# in the tier picker if the LD flag exposes a price-id). NO_TIER is not
|
||||
# priceable — it's the implicit "no active subscription" state.
|
||||
priceable_tiers = [
|
||||
SubscriptionTier.BASIC,
|
||||
SubscriptionTier.PRO,
|
||||
SubscriptionTier.MAX,
|
||||
SubscriptionTier.BUSINESS,
|
||||
]
|
||||
paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS]
|
||||
price_ids = await asyncio.gather(
|
||||
*[get_subscription_price_id(t) for t in priceable_tiers]
|
||||
*[get_subscription_price_id(t) for t in paid_tiers]
|
||||
)
|
||||
|
||||
tier_costs: dict[str, int] = {
|
||||
SubscriptionTier.FREE.value: 0,
|
||||
SubscriptionTier.ENTERPRISE.value: 0,
|
||||
}
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
|
||||
tier_costs: dict[str, int] = {}
|
||||
for t, pid, cost in zip(priceable_tiers, price_ids, costs):
|
||||
if pid:
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
# Expose the effective rate-limit multipliers alongside prices so the
|
||||
# frontend can render "Nx rate limits" relative to the lowest visible
|
||||
# tier without hard-coding backend defaults. Only emit entries for tiers
|
||||
# that land in ``tier_costs`` — rows hidden at the price layer must stay
|
||||
# hidden in the multiplier layer too.
|
||||
multipliers = await get_tier_multipliers()
|
||||
tier_multipliers: dict[str, float] = {
|
||||
t.value: multipliers.get(t, 1.0)
|
||||
for t in priceable_tiers
|
||||
if t.value in tier_costs
|
||||
}
|
||||
for t, cost in zip(paid_tiers, costs):
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
current_monthly_cost = tier_costs.get(tier.value, 0)
|
||||
proration_credit, current_period_end = await asyncio.gather(
|
||||
get_proration_credit_cents(user_id, current_monthly_cost),
|
||||
get_active_subscription_period_end(user_id),
|
||||
)
|
||||
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
|
||||
|
||||
try:
|
||||
pending = await get_pending_subscription_change(user_id)
|
||||
except (stripe.StripeError, PendingChangeUnknown):
|
||||
# Swallow Stripe-side failures (rate limits, transient network) AND
|
||||
# PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
|
||||
# propagate past the cache so the next request retries fresh instead
|
||||
# of serving a stale None for the TTL window. Let real bugs (KeyError,
|
||||
# AttributeError, etc.) propagate so they surface in Sentry.
|
||||
logger.exception(
|
||||
"get_subscription_status: failed to resolve pending change for user %s",
|
||||
user_id,
|
||||
)
|
||||
pending = None
|
||||
|
||||
response = SubscriptionStatusResponse(
|
||||
return SubscriptionStatusResponse(
|
||||
tier=tier.value,
|
||||
monthly_cost=current_monthly_cost,
|
||||
tier_costs=tier_costs,
|
||||
tier_multipliers=tier_multipliers,
|
||||
proration_credit_cents=proration_credit,
|
||||
has_active_stripe_subscription=current_period_end is not None,
|
||||
current_period_end=current_period_end,
|
||||
)
|
||||
if pending is not None:
|
||||
pending_tier_enum, pending_effective_at = pending
|
||||
if pending_tier_enum in (
|
||||
SubscriptionTier.NO_TIER,
|
||||
SubscriptionTier.BASIC,
|
||||
SubscriptionTier.PRO,
|
||||
SubscriptionTier.MAX,
|
||||
SubscriptionTier.BUSINESS,
|
||||
):
|
||||
response.pending_tier = pending_tier_enum.value
|
||||
response.pending_tier_effective_at = pending_effective_at
|
||||
return response
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/subscription",
|
||||
summary="Update subscription tier or start a Stripe Checkout session",
|
||||
summary="Start a Stripe Checkout session to upgrade subscription tier",
|
||||
operation_id="updateSubscriptionTier",
|
||||
tags=["credits"],
|
||||
dependencies=[Security(requires_user)],
|
||||
@@ -912,59 +822,38 @@ async def get_subscription_status(
|
||||
async def update_subscription_tier(
|
||||
request: SubscriptionTierRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionStatusResponse:
|
||||
# Pydantic validates tier is one of BASIC/PRO/MAX/BUSINESS via Literal type.
|
||||
) -> SubscriptionCheckoutResponse:
|
||||
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
|
||||
tier = SubscriptionTier(request.tier)
|
||||
|
||||
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
|
||||
user = await get_user_by_id(user_id)
|
||||
if (
|
||||
user.subscription_tier or SubscriptionTier.NO_TIER
|
||||
) == SubscriptionTier.ENTERPRISE:
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="ENTERPRISE subscription changes must be managed by an administrator",
|
||||
)
|
||||
|
||||
# Same-tier request = "stay on my current tier" = cancel any pending
|
||||
# scheduled change (paid→paid downgrade or paid→BASIC cancel). This is the
|
||||
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
|
||||
# route. Safe when no pending change exists: release_pending_subscription_schedule
|
||||
# returns False and we simply return the current status.
|
||||
if (user.subscription_tier or SubscriptionTier.NO_TIER) == tier:
|
||||
try:
|
||||
await release_pending_subscription_schedule(user_id)
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error releasing pending subscription change for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel the pending subscription change right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
return await get_subscription_status(user_id)
|
||||
|
||||
payment_enabled = await is_feature_enabled(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
target_price_id = await get_subscription_price_id(tier)
|
||||
|
||||
# Cancel: target NO_TIER. Schedule Stripe cancellation at period end;
|
||||
# cancel_at_period_end=True lets the webhook flip the DB tier. No active
|
||||
# sub (admin-granted or never-paid) or payment disabled → DB flip.
|
||||
# NO_TIER is never priceable, so this branch always fires for cancel
|
||||
# requests regardless of LD config.
|
||||
if tier == SubscriptionTier.NO_TIER:
|
||||
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
|
||||
# keeps their tier for the time they already paid for. The DB tier is NOT
|
||||
# updated here when a subscription exists — the customer.subscription.deleted
|
||||
# webhook fires at period end and downgrades to FREE then.
|
||||
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
|
||||
# tier), cancel_stripe_subscription returns False and we update the DB tier
|
||||
# immediately since no webhook will ever fire.
|
||||
# When payment is disabled entirely, update the DB tier directly.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
if payment_enabled:
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
# Log full Stripe error server-side but return a generic message
|
||||
# to the client — raw Stripe errors can leak customer/sub IDs and
|
||||
# infrastructure config details.
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
@@ -978,73 +867,63 @@ async def update_subscription_tier(
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
# No active Stripe subscription found — the user was on an
|
||||
# admin-granted tier. Update DB immediately since the
|
||||
# subscription.deleted webhook will never fire.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Paid tier changes require payment to be enabled — block self-service upgrades
|
||||
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
|
||||
if not payment_enabled:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier.value}",
|
||||
detail=f"Subscription not available for tier {tier}",
|
||||
)
|
||||
|
||||
# Target has no LD price — not provisionable (matches the GET hiding).
|
||||
if target_price_id is None:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier.value}",
|
||||
)
|
||||
# No-op short-circuit: if the user is already on the requested paid tier,
|
||||
# do NOT create a new Checkout Session. Without this guard, a duplicate
|
||||
# request (double-click, retried POST, stale page) creates a second
|
||||
# subscription for the same price; the user would be charged for both
|
||||
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
|
||||
# which only fires after the second charge has cleared.
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Modify in place if there's a sub; else fall through to Checkout below.
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return await get_subscription_status(user_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.InvalidRequestError as e:
|
||||
# Stripe rejects schedule modify when phases mix currencies, e.g. the
|
||||
# active sub was checked out in GBP but the target tier's Price is
|
||||
# USD-only. 502 reads as outage; surface a 422 with a specific message
|
||||
# so the user/admin can see what to fix in Stripe.
|
||||
msg = str(e)
|
||||
if "currency" in msg.lower():
|
||||
logger.warning(
|
||||
"Currency mismatch on tier change for user %s: %s", user_id, msg
|
||||
# Paid→paid tier change: if the user already has a Stripe subscription,
|
||||
# modify it in-place with proration instead of creating a new Checkout
|
||||
# Session. This preserves remaining paid time and avoids double-charging.
|
||||
# The customer.subscription.updated webhook fires and updates the DB tier.
|
||||
current_tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
# modify_stripe_subscription_for_tier returns False when no active
|
||||
# Stripe subscription exists — i.e. the user has an admin-granted
|
||||
# paid tier with no Stripe record. In that case, update the DB
|
||||
# tier directly (same as the FREE-downgrade path for admin-granted
|
||||
# users) rather than sending them through a new Checkout Session.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Tier change unavailable for your current billing currency."
|
||||
" Please contact support — the target tier needs to be"
|
||||
" configured for your currency in Stripe before this"
|
||||
" change can go through."
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
# No active Stripe subscription → create Stripe Checkout Session.
|
||||
# Paid upgrade from FREE → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
@@ -1099,9 +978,7 @@ async def update_subscription_tier(
|
||||
),
|
||||
)
|
||||
|
||||
status = await get_subscription_status(user_id)
|
||||
status.url = url
|
||||
return status
|
||||
return SubscriptionCheckoutResponse(url=url)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -1166,21 +1043,6 @@ async def stripe_webhook(request: Request):
|
||||
):
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
|
||||
# `subscription_schedule.updated` is deliberately omitted: our own
|
||||
# `SubscriptionSchedule.create` + `.modify` calls in
|
||||
# `_schedule_downgrade_at_period_end` would fire that event right back at us
|
||||
# and loop redundant traffic through this handler. We only care about state
|
||||
# transitions (released / completed); phase advance to the new price is
|
||||
# already covered by `customer.subscription.updated`.
|
||||
if event_type in (
|
||||
"subscription_schedule.released",
|
||||
"subscription_schedule.completed",
|
||||
):
|
||||
await sync_subscription_schedule_from_stripe(data_object)
|
||||
|
||||
if event_type == "invoice.payment_succeeded":
|
||||
await handle_subscription_payment_success(data_object)
|
||||
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
@@ -1778,10 +1640,6 @@ async def enable_execution_sharing(
|
||||
# Generate a unique share token
|
||||
share_token = str(uuid.uuid4())
|
||||
|
||||
# Remove stale allowlist records before updating the token — prevents a
|
||||
# window where old records + new token could coexist.
|
||||
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||
|
||||
# Update the execution with share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
@@ -1791,14 +1649,6 @@ async def enable_execution_sharing(
|
||||
shared_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Create allowlist of workspace files referenced in outputs
|
||||
await execution_db.create_shared_execution_files(
|
||||
execution_id=graph_exec_id,
|
||||
share_token=share_token,
|
||||
user_id=user_id,
|
||||
outputs=execution.outputs,
|
||||
)
|
||||
|
||||
# Return the share URL
|
||||
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||
share_url = f"{frontend_url}/share/{share_token}"
|
||||
@@ -1824,9 +1674,6 @@ async def disable_execution_sharing(
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Remove shared file allowlist records
|
||||
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||
|
||||
# Remove share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
@@ -1852,43 +1699,6 @@ async def get_shared_execution(
|
||||
return execution
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/public/shared/{share_token}/files/{file_id}/download",
|
||||
summary="Download a file from a shared execution",
|
||||
operation_id="download_shared_file",
|
||||
tags=["graphs"],
|
||||
)
|
||||
async def download_shared_file(
|
||||
share_token: Annotated[
|
||||
str,
|
||||
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
file_id: Annotated[
|
||||
str,
|
||||
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
) -> Response:
|
||||
"""Download a workspace file from a shared execution (no auth required).
|
||||
|
||||
Validates that the file was explicitly exposed when sharing was enabled.
|
||||
Returns a uniform 404 for all failure modes to prevent enumeration attacks.
|
||||
"""
|
||||
# Single-query validation against the allowlist
|
||||
execution_id = await execution_db.get_shared_execution_file(
|
||||
share_token=share_token, file_id=file_id
|
||||
)
|
||||
if not execution_id:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
# Look up the actual file (no workspace scoping needed — the allowlist
|
||||
# already validated that this file belongs to the shared execution)
|
||||
file = await get_workspace_file_by_id(file_id)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
return await create_file_download_response(file, inline=True)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tests for the public shared file download endpoint."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.responses import Response
|
||||
|
||||
from backend.api.features.v1 import v1_router
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(v1_router, prefix="/api")
|
||||
|
||||
VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
|
||||
VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
|
||||
|
||||
def _make_workspace_file(**overrides) -> WorkspaceFile:
|
||||
defaults = {
|
||||
"id": VALID_FILE_ID,
|
||||
"workspace_id": "ws-001",
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"name": "image.png",
|
||||
"path": "/image.png",
|
||||
"storage_path": "local://uploads/image.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 4,
|
||||
"checksum": None,
|
||||
"is_deleted": False,
|
||||
"deleted_at": None,
|
||||
"metadata": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkspaceFile(**defaults)
|
||||
|
||||
|
||||
def _mock_download_response(**kwargs):
|
||||
"""Return an AsyncMock that resolves to a Response with inline disposition."""
|
||||
|
||||
async def _handler(file, *, inline=False):
|
||||
return Response(
|
||||
content=b"\x89PNG",
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Content-Disposition": (
|
||||
'inline; filename="image.png"'
|
||||
if inline
|
||||
else 'attachment; filename="image.png"'
|
||||
),
|
||||
"Content-Length": "4",
|
||||
},
|
||||
)
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
class TestDownloadSharedFile:
|
||||
"""Tests for GET /api/public/shared/{token}/files/{id}/download."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _client(self):
|
||||
self.client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def test_valid_token_and_file_returns_inline_content(self):
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_make_workspace_file(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.create_file_download_response",
|
||||
side_effect=_mock_download_response(),
|
||||
),
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"\x89PNG"
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
def test_invalid_token_format_returns_422(self):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_token_not_in_allowlist_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_file_missing_from_workspace_returns_404(self):
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_uniform_404_prevents_enumeration(self):
|
||||
"""Both failure modes produce identical 404 — no information leak."""
|
||||
with patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
resp_no_allow = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
resp_no_file = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
assert resp_no_allow.status_code == 404
|
||||
assert resp_no_file.status_code == 404
|
||||
assert resp_no_allow.json() == resp_no_file.json()
|
||||
@@ -29,9 +29,7 @@ from backend.util.workspace import WorkspaceManager
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
def _sanitize_filename_for_header(
|
||||
filename: str, disposition: str = "attachment"
|
||||
) -> str:
|
||||
def _sanitize_filename_for_header(filename: str) -> str:
|
||||
"""
|
||||
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||
|
||||
@@ -46,11 +44,11 @@ def _sanitize_filename_for_header(
|
||||
# Check if filename has non-ASCII characters
|
||||
try:
|
||||
sanitized.encode("ascii")
|
||||
return f'{disposition}; filename="{sanitized}"'
|
||||
return f'attachment; filename="{sanitized}"'
|
||||
except UnicodeEncodeError:
|
||||
# Use RFC5987 encoding for UTF-8 filenames
|
||||
encoded = quote(sanitized, safe="")
|
||||
return f"{disposition}; filename*=UTF-8''{encoded}"
|
||||
return f"attachment; filename*=UTF-8''{encoded}"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -60,26 +58,19 @@ router = fastapi.APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _create_streaming_response(
|
||||
content: bytes, file: WorkspaceFile, *, inline: bool = False
|
||||
) -> Response:
|
||||
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
|
||||
"""Create a streaming response for file content."""
|
||||
disposition = _sanitize_filename_for_header(
|
||||
file.name, disposition="inline" if inline else "attachment"
|
||||
)
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=file.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": disposition,
|
||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||
"Content-Length": str(len(content)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def create_file_download_response(
|
||||
file: WorkspaceFile, *, inline: bool = False
|
||||
) -> Response:
|
||||
async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
"""
|
||||
Create a download response for a workspace file.
|
||||
|
||||
@@ -91,7 +82,7 @@ async def create_file_download_response(
|
||||
# For local storage, stream the file directly
|
||||
if file.storage_path.startswith("local://"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
return _create_streaming_response(content, file)
|
||||
|
||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||
try:
|
||||
@@ -99,7 +90,7 @@ async def create_file_download_response(
|
||||
# If we got back an API path (fallback), stream directly instead
|
||||
if url.startswith("/api/"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
return _create_streaming_response(content, file)
|
||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log the signed URL failure with context
|
||||
@@ -111,7 +102,7 @@ async def create_file_download_response(
|
||||
# Fall back to streaming directly from GCS
|
||||
try:
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
return _create_streaming_response(content, file)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Fallback streaming also failed for file {file.id} "
|
||||
@@ -178,7 +169,7 @@ async def download_file(
|
||||
if file is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await create_file_download_response(file)
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
|
||||
@@ -600,221 +600,3 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=11, offset=50, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
# -- _sanitize_filename_for_header tests --
|
||||
|
||||
|
||||
class TestSanitizeFilenameForHeader:
|
||||
def test_simple_ascii_attachment(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
assert _sanitize_filename_for_header("report.pdf") == (
|
||||
'attachment; filename="report.pdf"'
|
||||
)
|
||||
|
||||
def test_inline_disposition(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
assert _sanitize_filename_for_header("image.png", disposition="inline") == (
|
||||
'inline; filename="image.png"'
|
||||
)
|
||||
|
||||
def test_strips_cr_lf_null(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
|
||||
assert "\r" not in result
|
||||
assert "\n" not in result
|
||||
assert "\x00" not in result
|
||||
assert 'filename="abcd.txt"' in result
|
||||
|
||||
def test_escapes_quotes(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header('file"name.txt')
|
||||
assert 'filename="file\\"name.txt"' in result
|
||||
|
||||
def test_header_injection_blocked(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
|
||||
# CR/LF stripped — the remaining text is safely inside the quoted value
|
||||
assert "\r" not in result
|
||||
assert "\n" not in result
|
||||
assert result == 'attachment; filename="evil.txtX-Injected: true"'
|
||||
|
||||
def test_unicode_uses_rfc5987(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("日本語.pdf")
|
||||
assert "filename*=UTF-8''" in result
|
||||
assert "attachment" in result
|
||||
|
||||
def test_unicode_inline(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("图片.png", disposition="inline")
|
||||
assert result.startswith("inline; filename*=UTF-8''")
|
||||
|
||||
def test_empty_filename(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("")
|
||||
assert result == 'attachment; filename=""'
|
||||
|
||||
|
||||
# -- _create_streaming_response tests --
|
||||
|
||||
|
||||
class TestCreateStreamingResponse:
|
||||
def test_attachment_disposition_by_default(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name="data.bin", mime_type="application/octet-stream")
|
||||
response = _create_streaming_response(b"binary-data", file)
|
||||
assert (
|
||||
response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
|
||||
)
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["Content-Length"] == "11"
|
||||
assert response.body == b"binary-data"
|
||||
|
||||
def test_inline_disposition(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name="photo.png", mime_type="image/png")
|
||||
response = _create_streaming_response(b"\x89PNG", file, inline=True)
|
||||
assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
|
||||
assert response.headers["Content-Type"] == "image/png"
|
||||
|
||||
def test_inline_sanitizes_filename(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
|
||||
response = _create_streaming_response(b"data", file, inline=True)
|
||||
assert "\r" not in response.headers["Content-Disposition"]
|
||||
assert "\n" not in response.headers["Content-Disposition"]
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
def test_content_length_matches_body(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
content = b"x" * 1000
|
||||
file = _make_file(name="big.bin", mime_type="application/octet-stream")
|
||||
response = _create_streaming_response(content, file)
|
||||
assert response.headers["Content-Length"] == "1000"
|
||||
|
||||
|
||||
# -- create_file_download_response tests --
|
||||
|
||||
|
||||
class TestCreateFileDownloadResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_storage_returns_streaming_response(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b"file contents"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(
|
||||
storage_path="local://uploads/test.txt",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"file contents"
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_storage_inline(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b"\x89PNG"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(
|
||||
storage_path="local://uploads/photo.png",
|
||||
mime_type="image/png",
|
||||
name="photo.png",
|
||||
)
|
||||
response = await create_file_download_response(file, inline=True)
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_redirect(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.return_value = (
|
||||
"https://storage.googleapis.com/signed-url"
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.pdf")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers["location"] == "https://storage.googleapis.com/signed-url"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_api_fallback_streams_directly(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.return_value = "/api/fallback"
|
||||
mock_storage.retrieve.return_value = b"fallback content"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"fallback content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
|
||||
mock_storage.retrieve.return_value = b"streamed"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"streamed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_total_failure_raises(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
|
||||
mock_storage.retrieve.side_effect = RuntimeError("Also failed")
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
with pytest.raises(RuntimeError, match="Also failed"):
|
||||
await create_file_download_response(file)
|
||||
|
||||
@@ -17,7 +17,6 @@ from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.diagnostics_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.platform_cost_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
@@ -32,9 +31,7 @@ import backend.api.features.library.routes
|
||||
import backend.api.features.mcp.routes as mcp_routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.platform_linking.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
import backend.api.features.push.routes as push_routes
|
||||
import backend.api.features.store.model
|
||||
import backend.api.features.store.routes
|
||||
import backend.api.features.v1
|
||||
@@ -42,7 +39,6 @@ import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.redis_client
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
@@ -97,8 +93,6 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
verify_auth_settings()
|
||||
|
||||
await backend.data.db.connect()
|
||||
# Eager connect to fail-fast if Redis is unreachable.
|
||||
await backend.data.redis_client.get_redis_async()
|
||||
|
||||
# Configure thread pool for FastAPI sync operation performance
|
||||
# CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
|
||||
@@ -150,18 +144,7 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
except Exception as e:
|
||||
logger.warning(f"Error shutting down workspace storage: {e}")
|
||||
|
||||
# Each cleanup is wrapped so one failure doesn't block the rest. The
|
||||
# Redis close in particular silences asyncio's "Unclosed ClusterNode"
|
||||
# GC warning at interpreter shutdown.
|
||||
try:
|
||||
await backend.data.redis_client.disconnect_async()
|
||||
except Exception:
|
||||
logger.warning("redis_client.disconnect_async failed", exc_info=True)
|
||||
|
||||
try:
|
||||
await backend.data.db.disconnect()
|
||||
except Exception:
|
||||
logger.warning("db.disconnect failed", exc_info=True)
|
||||
await backend.data.db.disconnect()
|
||||
|
||||
|
||||
def custom_generate_unique_id(route: APIRoute):
|
||||
@@ -337,11 +320,6 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/credits",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.diagnostics_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.execution_analytics_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
@@ -394,16 +372,6 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
push_routes.router,
|
||||
tags=["push"],
|
||||
prefix="/api/push",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.platform_linking.routes.router,
|
||||
tags=["platform-linking"],
|
||||
prefix="/api/platform-linking",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Protocol
|
||||
@@ -16,12 +17,14 @@ from backend.api.model import (
|
||||
WSSubscribeGraphExecutionsRequest,
|
||||
)
|
||||
from backend.api.utils.cors import build_cors_params
|
||||
from backend.data import db, redis_client
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.notification_bus import AsyncRedisNotificationEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.monitoring.instrumentation import (
|
||||
instrument_fastapi,
|
||||
update_websocket_connections,
|
||||
)
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
@@ -31,24 +34,10 @@ settings = Settings()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Prisma is needed to resolve graph_id from graph_exec_id on subscribe.
|
||||
await db.connect()
|
||||
# Eager connect to fail-fast if Redis is unreachable.
|
||||
await redis_client.get_redis_async()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Each cleanup is wrapped so one failure doesn't block the rest. The
|
||||
# Redis close silences asyncio's "Unclosed ClusterNode" GC warning at
|
||||
# interpreter shutdown.
|
||||
try:
|
||||
await redis_client.disconnect_async()
|
||||
except Exception:
|
||||
logger.warning("redis_client.disconnect_async failed", exc_info=True)
|
||||
try:
|
||||
await db.disconnect()
|
||||
except Exception:
|
||||
logger.warning("db.disconnect failed", exc_info=True)
|
||||
manager = get_connection_manager()
|
||||
fut = asyncio.create_task(event_broadcaster(manager))
|
||||
fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
|
||||
yield
|
||||
|
||||
|
||||
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
|
||||
@@ -72,6 +61,31 @@ def get_connection_manager():
|
||||
return _connection_manager
|
||||
|
||||
|
||||
@continuous_retry()
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
execution_bus = AsyncRedisExecutionEventBus()
|
||||
notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
try:
|
||||
|
||||
async def execution_worker():
|
||||
async for event in execution_bus.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
|
||||
async def notification_worker():
|
||||
async for notification in notification_bus.listen("*"):
|
||||
await manager.send_notification(
|
||||
user_id=notification.user_id,
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
finally:
|
||||
# Ensure PubSub connections are closed on any exit to prevent leaks
|
||||
await execution_bus.close()
|
||||
await notification_bus.close()
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
if not settings.config.enable_auth:
|
||||
return DEFAULT_USER_ID
|
||||
@@ -283,21 +297,6 @@ async def websocket_router(
|
||||
).model_dump_json()
|
||||
)
|
||||
continue
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Subscription rejected for user #%s on '%s': %s",
|
||||
user_id,
|
||||
message.method.value,
|
||||
e,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WSMessage(
|
||||
method=WSMethod.ERROR,
|
||||
success=False,
|
||||
error=str(e),
|
||||
).model_dump_json()
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while handling '{message.method.value}' message "
|
||||
@@ -322,13 +321,9 @@ async def websocket_router(
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect_socket(websocket, user_id=user_id)
|
||||
logger.debug("WebSocket client disconnected")
|
||||
except Exception:
|
||||
logger.exception(f"Unexpected error in websocket_router for user #{user_id}")
|
||||
finally:
|
||||
# Always release subscription pumps + Redis connections, regardless of how
|
||||
# the loop exited — otherwise non-WebSocketDisconnect failures leak both.
|
||||
await manager.disconnect_socket(websocket, user_id=user_id)
|
||||
update_websocket_connections(user_id, -1)
|
||||
|
||||
|
||||
|
||||
@@ -44,12 +44,9 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
|
||||
"backend.api.ws_api.build_cors_params", return_value=cors_params
|
||||
)
|
||||
|
||||
with (
|
||||
override_config(
|
||||
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
|
||||
),
|
||||
override_config(settings, "app_env", AppEnvironment.LOCAL),
|
||||
):
|
||||
with override_config(
|
||||
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
|
||||
), override_config(settings, "app_env", AppEnvironment.LOCAL):
|
||||
WebsocketServer().run()
|
||||
|
||||
build_cors.assert_called_once_with(
|
||||
@@ -68,12 +65,9 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
|
||||
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
|
||||
mocker.patch("backend.api.ws_api.uvicorn.run")
|
||||
|
||||
with (
|
||||
override_config(
|
||||
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
|
||||
),
|
||||
override_config(settings, "app_env", AppEnvironment.PRODUCTION),
|
||||
):
|
||||
with override_config(
|
||||
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
|
||||
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
|
||||
with pytest.raises(ValueError):
|
||||
WebsocketServer().run()
|
||||
|
||||
@@ -296,232 +290,7 @@ async def test_handle_unsubscribe_missing_data(
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.unsubscribe_graph_exec.assert_not_called()
|
||||
mock_manager._unsubscribe.assert_not_called()
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
|
||||
|
||||
|
||||
# ---------- Per-graph subscribe branch ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_subscribe_graph_execs_branch(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
"""The SUBSCRIBE_GRAPH_EXECS branch must route to subscribe_graph_execs,
|
||||
not subscribe_graph_exec — regression guard for the aggregate channel."""
|
||||
message = WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXECS,
|
||||
data={"graph_id": "graph-abc"},
|
||||
)
|
||||
mock_manager.subscribe_graph_execs.return_value = (
|
||||
"user-1|graph#graph-abc|executions"
|
||||
)
|
||||
|
||||
await handle_subscribe(
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.subscribe_graph_execs.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
graph_id="graph-abc",
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_manager.subscribe_graph_exec.assert_not_called()
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert (
|
||||
'"method":"subscribe_graph_executions"'
|
||||
in mock_websocket.send_text.call_args[0][0]
|
||||
)
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_subscribe_rejects_unrelated_method(
|
||||
mock_websocket: AsyncMock, mock_manager: AsyncMock
|
||||
) -> None:
|
||||
"""handle_subscribe must raise for methods that aren't SUBSCRIBE_*."""
|
||||
import pytest as _pytest
|
||||
|
||||
message = WSMessage(
|
||||
method=WSMethod.HEARTBEAT,
|
||||
data={"graph_exec_id": "x"},
|
||||
)
|
||||
|
||||
with _pytest.raises(ValueError):
|
||||
await handle_subscribe(
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
# ---------- authenticate_websocket branches ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_missing_token_closes_4001(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
ws.close.assert_awaited_once()
|
||||
assert ws.close.call_args.kwargs["code"] == 4001
|
||||
assert user_id == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_invalid_token_closes_4003(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
mocker.patch(
|
||||
"backend.api.ws_api.parse_jwt_token", side_effect=ValueError("bad token")
|
||||
)
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {"token": "abc"}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
ws.close.assert_awaited_once()
|
||||
assert ws.close.call_args.kwargs["code"] == 4003
|
||||
assert user_id == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_missing_sub_closes_4002(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"not_sub": "x"})
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {"token": "abc"}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
ws.close.assert_awaited_once()
|
||||
assert ws.close.call_args.kwargs["code"] == 4002
|
||||
assert user_id == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_happy_path_returns_sub(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"sub": "user-X"})
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {"token": "abc"}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
assert user_id == "user-X"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_auth_disabled_returns_default(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", False)
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
assert user_id == DEFAULT_USER_ID
|
||||
|
||||
|
||||
# ---------- get_connection_manager singleton ----------
|
||||
|
||||
|
||||
def test_get_connection_manager_singleton() -> None:
|
||||
"""Repeated calls must return the same ConnectionManager — the WS router
|
||||
depends on a single process-wide subscription table."""
|
||||
import backend.api.ws_api as ws_api
|
||||
|
||||
ws_api._connection_manager = None
|
||||
a = ws_api.get_connection_manager()
|
||||
b = ws_api.get_connection_manager()
|
||||
assert a is b
|
||||
assert isinstance(a, ConnectionManager)
|
||||
|
||||
|
||||
# ---------- Lifespan: Prisma connect/disconnect ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_connects_and_disconnects_prisma(mocker) -> None:
|
||||
"""Lifespan must both connect() and disconnect() db — the subscribe path
|
||||
resolves graph_id via Prisma so a missing connect() is the regression bug."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.ws_api import lifespan
|
||||
|
||||
mock_db = mocker.patch("backend.api.ws_api.db")
|
||||
mock_db.connect = AsyncMock()
|
||||
mock_db.disconnect = AsyncMock()
|
||||
|
||||
dummy_app = FastAPI()
|
||||
async with lifespan(dummy_app):
|
||||
mock_db.connect.assert_awaited_once()
|
||||
mock_db.disconnect.assert_not_called()
|
||||
mock_db.disconnect.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_still_disconnects_on_exception(mocker) -> None:
|
||||
"""If the app raises inside the yield, Prisma must still disconnect."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.ws_api import lifespan
|
||||
|
||||
mock_db = mocker.patch("backend.api.ws_api.db")
|
||||
mock_db.connect = AsyncMock()
|
||||
mock_db.disconnect = AsyncMock()
|
||||
|
||||
dummy_app = FastAPI()
|
||||
|
||||
class _Boom(Exception):
|
||||
pass
|
||||
|
||||
with pytest.raises(_Boom):
|
||||
async with lifespan(dummy_app):
|
||||
raise _Boom()
|
||||
|
||||
mock_db.disconnect.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------- Health endpoint ----------
|
||||
|
||||
|
||||
def test_health_endpoint_returns_ok() -> None:
|
||||
# TestClient triggers lifespan — stub it out so Prisma isn't hit.
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.ws_api as ws_api
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_lifespan(app):
|
||||
yield
|
||||
|
||||
# Replace the app-level lifespan temporarily.
|
||||
real_router_lifespan = ws_api.app.router.lifespan_context
|
||||
ws_api.app.router.lifespan_context = _noop_lifespan
|
||||
try:
|
||||
with TestClient(ws_api.app) as client:
|
||||
r = client.get("/")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "healthy"}
|
||||
finally:
|
||||
ws_api.app.router.lifespan_context = real_router_lifespan
|
||||
|
||||
@@ -38,23 +38,19 @@ def main(**kwargs):
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.copilot.bot.app import CoPilotChatBridge
|
||||
from backend.copilot.executor.manager import CoPilotExecutor
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.platform_linking.manager import PlatformLinkingManager
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
Scheduler(),
|
||||
NotificationManager(),
|
||||
PlatformLinkingManager(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
CoPilotExecutor(),
|
||||
CoPilotChatBridge(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -96,64 +96,27 @@ class BlockCategory(Enum):
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
# RUN : cost_amount credits per run.
|
||||
# BYTE : cost_amount credits per byte of input data.
|
||||
# SECOND : cost_amount credits per cost_divisor walltime seconds.
|
||||
# ITEMS : cost_amount credits per cost_divisor items (from stats).
|
||||
# COST_USD : cost_amount credits per USD of stats.provider_cost.
|
||||
# TOKENS : per-(model, provider) rate table lookup; see TOKEN_COST.
|
||||
RUN = "run"
|
||||
BYTE = "byte"
|
||||
SECOND = "second"
|
||||
ITEMS = "items"
|
||||
COST_USD = "cost_usd"
|
||||
TOKENS = "tokens"
|
||||
|
||||
@property
|
||||
def is_dynamic(self) -> bool:
|
||||
"""Real charge is computed post-flight from stats.
|
||||
|
||||
Dynamic types (SECOND/ITEMS/COST_USD/TOKENS) return 0 pre-flight and
|
||||
settle against stats via charge_reconciled_usage once the block runs.
|
||||
"""
|
||||
return self in _DYNAMIC_COST_TYPES
|
||||
|
||||
|
||||
_DYNAMIC_COST_TYPES: frozenset[BlockCostType] = frozenset(
|
||||
{
|
||||
BlockCostType.SECOND,
|
||||
BlockCostType.ITEMS,
|
||||
BlockCostType.COST_USD,
|
||||
BlockCostType.TOKENS,
|
||||
}
|
||||
)
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
# cost_divisor: interpret cost_amount as "credits per cost_divisor units".
|
||||
# Only meaningful for SECOND / ITEMS. TOKENS routes through TOKEN_COST
|
||||
# rate tables (per-model input/output/cache pricing) and ignores
|
||||
# cost_divisor entirely. Defaults to 1 so existing RUN/BYTE entries stay
|
||||
# point-wise. Example: cost_amount=1, cost_divisor=10 under SECOND means
|
||||
# "1 credit per 10 seconds of walltime".
|
||||
cost_divisor: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
cost_divisor: int = 1,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
cost_divisor=max(1, cost_divisor),
|
||||
**data,
|
||||
)
|
||||
|
||||
@@ -205,31 +168,9 @@ class BlockSchema(BaseModel):
|
||||
return cls.cached_jsonschema
|
||||
|
||||
@classmethod
|
||||
def validate_data(
|
||||
cls,
|
||||
data: BlockInput,
|
||||
exclude_fields: set[str] | None = None,
|
||||
) -> str | None:
|
||||
schema = cls.jsonschema()
|
||||
if exclude_fields:
|
||||
# Drop the excluded fields from both the properties and the
|
||||
# ``required`` list so jsonschema doesn't flag them as missing.
|
||||
# Used by the dry-run path to skip credentials validation while
|
||||
# still validating the remaining block inputs.
|
||||
schema = {
|
||||
**schema,
|
||||
"properties": {
|
||||
k: v
|
||||
for k, v in schema.get("properties", {}).items()
|
||||
if k not in exclude_fields
|
||||
},
|
||||
"required": [
|
||||
r for r in schema.get("required", []) if r not in exclude_fields
|
||||
],
|
||||
}
|
||||
data = {k: v for k, v in data.items() if k not in exclude_fields}
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(
|
||||
schema=schema,
|
||||
schema=cls.jsonschema(),
|
||||
data={k: v for k, v in data.items() if v is not None},
|
||||
)
|
||||
|
||||
@@ -370,8 +311,6 @@ class BlockSchema(BaseModel):
|
||||
"credentials_provider": [config.get("provider", "google")],
|
||||
"credentials_types": [config.get("type", "oauth2")],
|
||||
"credentials_scopes": config.get("scopes"),
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": info["field_name"],
|
||||
}
|
||||
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||
auto_schema, by_alias=True
|
||||
@@ -482,6 +421,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
|
||||
"""Return extra runtime cost to charge after this block run completes.
|
||||
|
||||
Called by the executor after a block finishes with COMPLETED status.
|
||||
The return value is the number of additional base-cost credits to
|
||||
charge beyond the single credit already collected by charge_usage
|
||||
at the start of execution. Defaults to 0 (no extra charges).
|
||||
|
||||
Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
|
||||
calls within one run and should be billed per call.
|
||||
"""
|
||||
return 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
@@ -765,16 +717,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
# (e.g. AgentExecutorBlock) get proper input validation.
|
||||
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
|
||||
if is_dry_run:
|
||||
# Credential fields may be absent (LLM-built agents often skip
|
||||
# wiring them) or nullified earlier in the pipeline. Validate
|
||||
# the non-credential inputs against a schema with those fields
|
||||
# excluded — stripping only the data while keeping them in the
|
||||
# ``required`` list would falsely report ``'credentials' is a
|
||||
# required property``.
|
||||
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
|
||||
if error := self.input_schema.validate_data(
|
||||
input_data, exclude_fields=cred_field_names
|
||||
):
|
||||
non_cred_data = {
|
||||
k: v for k, v in input_data.items() if k not in cred_field_names
|
||||
}
|
||||
if error := self.input_schema.validate_data(non_cred_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
@@ -788,61 +735,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Ensure auto-credential kwargs are present before we hand off to
|
||||
# run(). A missing auto-credential means the upstream field (e.g.
|
||||
# a Google Drive picker) didn't embed a _credentials_id, or the
|
||||
# executor couldn't resolve it. Without this guard, run() would
|
||||
# crash with a TypeError (missing required kwarg) or an opaque
|
||||
# AttributeError deep inside the provider SDK.
|
||||
#
|
||||
# Only raise when the field is ALSO not populated in input_data.
|
||||
# ``_acquire_auto_credentials`` intentionally skips setting the
|
||||
# kwarg in two legitimate cases — ``_credentials_id`` is ``None``
|
||||
# (chained from upstream) or the field is missing from
|
||||
# ``input_data`` at prep time (connected from upstream block).
|
||||
# In both cases the upstream block is expected to populate the
|
||||
# field value by execute time; raising here would break the
|
||||
# documented ``AgentGoogleDriveFileInputBlock`` chaining pattern.
|
||||
# Dry-run skips because the executor intentionally runs blocks
|
||||
# without resolved creds for schema validation.
|
||||
if not is_dry_run:
|
||||
for (
|
||||
kwarg_name,
|
||||
info,
|
||||
) in self.input_schema.get_auto_credentials_fields().items():
|
||||
kwargs.setdefault(kwarg_name, None)
|
||||
if kwargs[kwarg_name] is not None:
|
||||
continue
|
||||
# Upstream-chained pattern: the field was populated by a
|
||||
# prior node (e.g. AgentGoogleDriveFileInputBlock) whose
|
||||
# output carries a resolved ``_credentials_id``.
|
||||
# ``_acquire_auto_credentials`` deliberately doesn't set
|
||||
# the kwarg in that case because the value isn't available
|
||||
# at prep time; the executor fills it in before we reach
|
||||
# ``_execute``. Trust it if the ``_credentials_id`` KEY
|
||||
# is present — its value may be explicitly ``None`` in
|
||||
# the chained case (see sentry thread
|
||||
# PRRT_kwDOJKSTjM58sJfA). Checking truthiness here would
|
||||
# falsely preempt run() for every valid chained graph
|
||||
# that ships ``_credentials_id=None`` in the picker
|
||||
# object. Mirror ``_acquire_auto_credentials``'s own
|
||||
# skip rule, which treats ``cred_id is None`` as a
|
||||
# chained-skip signal.
|
||||
field_name = info["field_name"]
|
||||
field_value = input_data.get(field_name)
|
||||
if isinstance(field_value, dict) and "_credentials_id" in field_value:
|
||||
continue
|
||||
raise BlockExecutionError(
|
||||
message=(
|
||||
f"Missing credentials for '{kwarg_name}'. "
|
||||
"Select a file via the picker (which carries "
|
||||
"its credentials), or connect credentials for "
|
||||
"this block."
|
||||
),
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
"""Provider descriptions for services that don't yet have their own ``_config.py``.
|
||||
|
||||
Every provider in ``_STATIC_PROVIDER_CONFIGS`` below is declared here because
|
||||
its block code currently lives either in a single shared file (e.g. the 8 LLM
|
||||
providers in ``blocks/llm.py``) or in a single-file block that has no dedicated
|
||||
directory (e.g. ``blocks/reddit.py``).
|
||||
|
||||
This file gets loaded by the block auto-loader in ``blocks/__init__.py``
|
||||
(``rglob("*.py")`` picks it up) so the ``ProviderBuilder(...).build()`` calls
|
||||
run at startup and populate ``AutoRegistry`` before the first API request.
|
||||
|
||||
**Migration path:** when a provider graduates into its own directory with a
|
||||
proper ``_config.py`` (following the SDK pattern, e.g. ``blocks/linear/_config.py``),
|
||||
delete its entry here. The metadata will still be served by
|
||||
``GET /integrations/providers`` — it just moves to live next to the provider's
|
||||
auth and webhook config.
|
||||
"""
|
||||
|
||||
from backend.data.model import CredentialsType
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
_STATIC_PROVIDER_CONFIGS: dict[str, tuple[str, tuple[CredentialsType, ...]]] = {
|
||||
# LLM providers that share blocks/llm.py
|
||||
"aiml_api": ("Unified access to 100+ AI models", ("api_key",)),
|
||||
"anthropic": ("Claude language models", ("api_key",)),
|
||||
"groq": ("Fast LLM inference", ("api_key",)),
|
||||
"llama_api": ("Llama model hosting", ("api_key",)),
|
||||
"ollama": ("Run open-source LLMs locally", ("api_key",)),
|
||||
"open_router": ("One API for every LLM", ("api_key",)),
|
||||
"openai": ("GPT models and embeddings", ("api_key",)),
|
||||
"v0": ("AI-generated UI components", ("api_key",)),
|
||||
# Single-file providers (one provider per standalone blocks/*.py file)
|
||||
"d_id": ("AI avatar and video generation", ("api_key",)),
|
||||
"e2b": ("Sandboxed code execution", ("api_key",)),
|
||||
"google_maps": ("Places, directions, geocoding", ("api_key",)),
|
||||
"http": ("Generic HTTP requests", ("api_key", "host_scoped")),
|
||||
"ideogram": ("Text-to-image generation", ("api_key",)),
|
||||
"medium": ("Publish stories and posts", ("api_key",)),
|
||||
"mem0": ("Long-term memory for agents", ("api_key",)),
|
||||
"openweathermap": ("Weather data and forecasts", ("api_key",)),
|
||||
"pinecone": ("Managed vector database", ("api_key",)),
|
||||
"reddit": ("Subreddits, posts, and comments", ("oauth2",)),
|
||||
"revid": ("AI-generated short-form video", ("api_key",)),
|
||||
"screenshotone": ("Automated website screenshots", ("api_key",)),
|
||||
"smtp": ("Send email via SMTP", ("user_password",)),
|
||||
"unreal_speech": ("Low-cost text-to-speech", ("api_key",)),
|
||||
"webshare_proxy": ("Rotating proxies for scraping", ("api_key",)),
|
||||
}
|
||||
|
||||
for _name, (_description, _auth_types) in _STATIC_PROVIDER_CONFIGS.items():
|
||||
(
|
||||
ProviderBuilder(_name)
|
||||
.with_description(_description)
|
||||
.with_supported_auth_types(*_auth_types)
|
||||
.build()
|
||||
)
|
||||
@@ -171,10 +171,7 @@ class AgentExecutorBlock(Block):
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
# Sub-graph already debited each of its own nodes; we
|
||||
# roll up its total so graph_stats.cost reflects the
|
||||
# full sub-graph spend.
|
||||
reconciled_cost_delta=(event.stats.cost if event.stats else 0),
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
extra_steps=event.stats.node_exec_count if event.stats else 0,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,17 +4,11 @@ Shared configuration for all AgentMail blocks.
|
||||
|
||||
from agentmail import AsyncAgentMail
|
||||
|
||||
from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, SecretStr
|
||||
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
|
||||
|
||||
# AgentMail is in beta with no published paid tier yet, but ~37 blocks
|
||||
# without any BLOCK_COSTS entry means they currently execute wallet-free.
|
||||
# 1 cr/call is a conservative interim floor so no AgentMail work leaks
|
||||
# past billing. Revisit once AgentMail publishes usage-based pricing.
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_description("Managed email accounts for agents")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from ._webhook import AirtableWebhookManager
|
||||
# Configure the Airtable provider with API key authentication
|
||||
airtable = (
|
||||
ProviderBuilder("airtable")
|
||||
.with_description("Bases, tables, and records")
|
||||
.with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
|
||||
.with_webhook_manager(AirtableWebhookManager)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Provider registration for Apollo.
|
||||
|
||||
Registers the provider description shown in the settings integrations UI.
|
||||
Apollo doesn't use a full :class:`ProviderBuilder` chain (auth is set up in
|
||||
``_auth.py``), so this file only declares metadata.
|
||||
"""
|
||||
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
apollo = (
|
||||
ProviderBuilder("apollo")
|
||||
.with_description("Sales intelligence and prospecting")
|
||||
.with_supported_auth_types("api_key")
|
||||
.build()
|
||||
)
|
||||
@@ -7,7 +7,6 @@ import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import field_validator
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -18,14 +17,12 @@ from backend.blocks._base import (
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.copilot.permissions import (
|
||||
DISABLED_LEGACY_TOOL_NAMES,
|
||||
CopilotPermissions,
|
||||
ToolName,
|
||||
all_known_tool_names,
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
@@ -35,36 +32,9 @@ logger = logging.getLogger(__name__)
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
# Identifiers used when registering an AutoPilotBlock turn with the
|
||||
# stream registry — distinguishes block-originated turns from sub-session
|
||||
# or HTTP SSE turns in logs / observability.
|
||||
_AUTOPILOT_TOOL_CALL_ID = "autopilot_block"
|
||||
_AUTOPILOT_TOOL_NAME = "autopilot_block"
|
||||
|
||||
# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the
|
||||
# enqueued turn's terminal event. Graph blocks run synchronously from
|
||||
# the caller's perspective so we wait effectively as long as needed; 6h
|
||||
# matches the previous abandoned-task cap and is much longer than any
|
||||
# legitimate AutoPilot turn.
|
||||
_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60 # 6 hours
|
||||
|
||||
|
||||
class SubAgentRecursionError(BlockExecutionError):
|
||||
"""Raised when the AutoPilot sub-agent nesting depth limit is exceeded.
|
||||
|
||||
Inherits :class:`BlockExecutionError` — this is a known, handled
|
||||
runtime failure at the block level (caller nested AutoPilotBlocks
|
||||
beyond the configured limit). Surfaces with the block_name /
|
||||
block_id the block framework expects, instead of being wrapped in
|
||||
``BlockUnknownError``.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(
|
||||
message=message,
|
||||
block_name="AutoPilotBlock",
|
||||
block_id=AUTOPILOT_BLOCK_ID,
|
||||
)
|
||||
class SubAgentRecursionError(RuntimeError):
|
||||
"""Raised when the sub-agent nesting depth limit is exceeded."""
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
@@ -200,13 +170,6 @@ class AutoPilotBlock(Block):
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def strip_disabled_legacy_tools(cls, tools: Any) -> Any:
|
||||
if not isinstance(tools, list):
|
||||
return tools
|
||||
return [tool for tool in tools if tool not in DISABLED_LEGACY_TOOL_NAMES]
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for the AutoPilot block."""
|
||||
|
||||
@@ -305,15 +268,11 @@ class AutoPilotBlock(Block):
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot on the copilot_executor queue and aggregate the
|
||||
result.
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
Delegates to :func:`run_copilot_turn_via_queue` — the shared
|
||||
primitive used by ``run_sub_session`` too — which creates the
|
||||
stream_registry meta record, enqueues the job, and waits on the
|
||||
Redis stream for the terminal event. Any available
|
||||
copilot_executor worker picks up the job, so this call survives
|
||||
the graph-executor worker dying mid-turn (RabbitMQ redelivers).
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
@@ -326,8 +285,8 @@ class AutoPilotBlock(Block):
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.session_waiter import (
|
||||
run_copilot_turn_via_queue, # avoid circular import
|
||||
from backend.copilot.sdk.collect import (
|
||||
collect_copilot_response, # avoid circular import
|
||||
)
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
@@ -340,35 +299,14 @@ class AutoPilotBlock(Block):
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
outcome, result = await run_copilot_turn_via_queue(
|
||||
result = await collect_copilot_response(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=effective_prompt,
|
||||
# Graph block execution is synchronous from the caller's
|
||||
# perspective — wait effectively as long as needed. The
|
||||
# SDK enforces its own idle-based timeout inside the
|
||||
# stream_registry pipeline.
|
||||
timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS,
|
||||
user_id=user_id,
|
||||
permissions=effective_permissions,
|
||||
tool_call_id=_AUTOPILOT_TOOL_CALL_ID,
|
||||
tool_name=_AUTOPILOT_TOOL_NAME,
|
||||
)
|
||||
if outcome == "failed":
|
||||
raise RuntimeError(
|
||||
"AutoPilot turn failed — see the session's transcript"
|
||||
)
|
||||
if outcome == "running":
|
||||
raise RuntimeError(
|
||||
"AutoPilot turn did not complete within "
|
||||
f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session "
|
||||
f"{session_id}"
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from the aggregated data.
|
||||
# When ``result.queued`` is True the prompt rode on an already-
|
||||
# in-flight turn (``run_copilot_turn_via_queue`` queued it and
|
||||
# waited on the existing turn's stream); the aggregated result
|
||||
# is still valid, so the same rendering path applies.
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
@@ -377,7 +315,7 @@ class AutoPilotBlock(Block):
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": [tc.model_dump() for tc in result.tool_calls],
|
||||
"tool_calls": result.tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -388,11 +326,11 @@ class AutoPilotBlock(Block):
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc.tool_call_id,
|
||||
"tool_name": tc.tool_name,
|
||||
"input": tc.input,
|
||||
"output": tc.output,
|
||||
"success": tc.success,
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
|
||||
@@ -62,14 +62,6 @@ class TestBuildAndValidatePermissions:
|
||||
with pytest.raises(ValidationError, match="not_a_real_tool"):
|
||||
_make_input(tools=["not_a_real_tool"])
|
||||
|
||||
async def test_disabled_legacy_tool_is_accepted_and_removed(self):
|
||||
inp = _make_input(tools=["ask_question", "run_block"])
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
|
||||
assert inp.tools == ["run_block"]
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.tools == ["run_block"]
|
||||
|
||||
async def test_valid_block_name_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Shared provider config for Ayrshare social-media blocks.
|
||||
|
||||
The "credential" exposed to blocks is the **per-user Ayrshare profile key**,
|
||||
not the org-level ``AYRSHARE_API_KEY``. Profile keys are provisioned per
|
||||
user by :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`
|
||||
and stored in the normal credentials list with ``is_managed=True``, so every
|
||||
Ayrshare block fits the standard credential flow:
|
||||
|
||||
credentials: CredentialsMetaInput = ayrshare.credentials_field(...)
|
||||
|
||||
``run_block`` / ``resolve_block_credentials`` take care of the rest.
|
||||
|
||||
``with_managed_api_key()`` registers ``api_key`` as a supported auth type
|
||||
without the env-var-backed default credential that ``with_api_key()`` would
|
||||
create — the org-level ``AYRSHARE_API_KEY`` is the admin key and must never
|
||||
reach a block as a "profile key".
|
||||
"""
|
||||
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
ayrshare = (
|
||||
ProviderBuilder("ayrshare")
|
||||
.with_description("Post to every social network")
|
||||
.with_managed_api_key()
|
||||
.build()
|
||||
)
|
||||
@@ -1,18 +0,0 @@
|
||||
from backend.sdk import BlockCost, BlockCostType
|
||||
|
||||
# Ayrshare is a subscription proxy ($149/mo Business). Per-post credit charges
|
||||
# prevent a single heavy user from absorbing the fixed cost and align with the
|
||||
# upload cost of each post variant.
|
||||
# cost_filter matches on input_data.is_video BEFORE run() executes, so the flag
|
||||
# has to be correct at input-eval time. Video-only platforms (YouTube, Snapchat)
|
||||
# override the base default to True; platforms that accept both (TikTok, etc.)
|
||||
# rely on the caller setting is_video explicitly for accurate billing.
|
||||
# First match wins in block_usage_cost, so list the video tier first.
|
||||
AYRSHARE_POST_COSTS = (
|
||||
BlockCost(
|
||||
cost_amount=5, cost_type=BlockCostType.RUN, cost_filter={"is_video": True}
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=2, cost_type=BlockCostType.RUN, cost_filter={"is_video": False}
|
||||
),
|
||||
)
|
||||
@@ -4,25 +4,22 @@ from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.blocks._base import BlockSchemaInput
|
||||
from backend.data.model import CredentialsMetaInput, SchemaField
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
from ._config import ayrshare
|
||||
|
||||
async def get_profile_key(user_id: str):
|
||||
user_integrations: UserIntegrations = (
|
||||
await get_database_manager_async_client().get_user_integrations(user_id)
|
||||
)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
|
||||
class BaseAyrshareInput(BlockSchemaInput):
|
||||
"""Base input model for Ayrshare social media posts with common fields."""
|
||||
|
||||
credentials: CredentialsMetaInput = ayrshare.credentials_field(
|
||||
description=(
|
||||
"Ayrshare profile credential. AutoGPT provisions this managed "
|
||||
"credential automatically — the user does not create it. After "
|
||||
"it's in place, the user links each social account via the "
|
||||
"Ayrshare SSO popup in the Builder."
|
||||
),
|
||||
)
|
||||
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published", default="", advanced=False
|
||||
)
|
||||
@@ -32,9 +29,7 @@ class BaseAyrshareInput(BlockSchemaInput):
|
||||
advanced=False,
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video. Set to True when uploading a video so billing applies the video tier.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
description="Whether the media is a video", default=False, advanced=True
|
||||
)
|
||||
schedule_date: Optional[datetime] = SchemaField(
|
||||
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToBlueskyBlock(Block):
|
||||
"""Block for posting to Bluesky with Bluesky-specific options."""
|
||||
|
||||
@@ -61,10 +57,16 @@ class PostToBlueskyBlock(Block):
|
||||
self,
|
||||
input_data: "PostToBlueskyBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Bluesky with Bluesky-specific options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -104,7 +106,7 @@ class PostToBlueskyBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
bluesky_options=bluesky_options if bluesky_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
CarouselItem,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToFacebookBlock(Block):
|
||||
"""Block for posting to Facebook with Facebook-specific options."""
|
||||
|
||||
@@ -119,10 +120,15 @@ class PostToFacebookBlock(Block):
|
||||
self,
|
||||
input_data: "PostToFacebookBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Facebook with Facebook-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -198,7 +204,7 @@ class PostToFacebookBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
facebook_options=facebook_options if facebook_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToGMBBlock(Block):
|
||||
"""Block for posting to Google My Business with GMB-specific options."""
|
||||
|
||||
@@ -114,13 +110,14 @@ class PostToGMBBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToGMBBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs
|
||||
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to Google My Business with GMB-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -205,7 +202,7 @@ class PostToGMBBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
gmb_options=gmb_options if gmb_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -2,21 +2,22 @@ from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
InstagramUserTag,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToInstagramBlock(Block):
|
||||
"""Block for posting to Instagram with Instagram-specific options."""
|
||||
|
||||
@@ -111,10 +112,15 @@ class PostToInstagramBlock(Block):
|
||||
self,
|
||||
input_data: "PostToInstagramBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Instagram with Instagram-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -235,7 +241,7 @@ class PostToInstagramBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
instagram_options=instagram_options if instagram_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToLinkedInBlock(Block):
|
||||
"""Block for posting to LinkedIn with LinkedIn-specific options."""
|
||||
|
||||
@@ -116,10 +112,15 @@ class PostToLinkedInBlock(Block):
|
||||
self,
|
||||
input_data: "PostToLinkedInBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to LinkedIn with LinkedIn-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -213,7 +214,7 @@ class PostToLinkedInBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
linkedin_options=linkedin_options if linkedin_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
PinterestCarouselOption,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToPinterestBlock(Block):
|
||||
"""Block for posting to Pinterest with Pinterest-specific options."""
|
||||
|
||||
@@ -91,10 +92,15 @@ class PostToPinterestBlock(Block):
|
||||
self,
|
||||
input_data: "PostToPinterestBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Pinterest with Pinterest-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -200,7 +206,7 @@ class PostToPinterestBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
pinterest_options=pinterest_options if pinterest_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToRedditBlock(Block):
|
||||
"""Block for posting to Reddit."""
|
||||
|
||||
@@ -39,12 +35,12 @@ class PostToRedditBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToRedditBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs
|
||||
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured."
|
||||
@@ -65,7 +61,7 @@ class PostToRedditBlock(Block):
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToSnapchatBlock(Block):
|
||||
"""Block for posting to Snapchat with Snapchat-specific options."""
|
||||
|
||||
@@ -35,14 +31,6 @@ class PostToSnapchatBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Snapchat is video-only; override the base default so the @cost filter
|
||||
# selects the 5-credit video tier instead of the 2-credit image tier.
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video (always True for Snapchat)",
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Snapchat-specific options
|
||||
story_type: str = SchemaField(
|
||||
description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
|
||||
@@ -74,10 +62,15 @@ class PostToSnapchatBlock(Block):
|
||||
self,
|
||||
input_data: "PostToSnapchatBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Snapchat with Snapchat-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -128,7 +121,7 @@ class PostToSnapchatBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
snapchat_options=snapchat_options if snapchat_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToTelegramBlock(Block):
|
||||
"""Block for posting to Telegram with Telegram-specific options."""
|
||||
|
||||
@@ -61,10 +57,15 @@ class PostToTelegramBlock(Block):
|
||||
self,
|
||||
input_data: "PostToTelegramBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Telegram with Telegram-specific validation."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -107,7 +108,7 @@ class PostToTelegramBlock(Block):
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToThreadsBlock(Block):
|
||||
"""Block for posting to Threads with Threads-specific options."""
|
||||
|
||||
@@ -54,10 +50,15 @@ class PostToThreadsBlock(Block):
|
||||
self,
|
||||
input_data: "PostToThreadsBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Threads with Threads-specific validation."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -102,7 +103,7 @@ class PostToThreadsBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
threads_options=threads_options if threads_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -2,18 +2,15 @@ from enum import Enum
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class TikTokVisibility(str, Enum):
|
||||
@@ -22,7 +19,6 @@ class TikTokVisibility(str, Enum):
|
||||
FOLLOWERS = "followers"
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToTikTokBlock(Block):
|
||||
"""Block for posting to TikTok with TikTok-specific options."""
|
||||
|
||||
@@ -117,13 +113,14 @@ class PostToTikTokBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToTikTokBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to TikTok with TikTok-specific validation and options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -238,7 +235,7 @@ class PostToTikTokBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
tiktok_options=tiktok_options if tiktok_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToXBlock(Block):
|
||||
"""Block for posting to X / Twitter with Twitter-specific options."""
|
||||
|
||||
@@ -119,10 +115,15 @@ class PostToXBlock(Block):
|
||||
self,
|
||||
input_data: "PostToXBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to X / Twitter with enhanced X-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -232,7 +233,7 @@ class PostToXBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
twitter_options=twitter_options if twitter_options else None,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -3,18 +3,15 @@ from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._cost import AYRSHARE_POST_COSTS
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class YouTubeVisibility(str, Enum):
|
||||
@@ -23,7 +20,6 @@ class YouTubeVisibility(str, Enum):
|
||||
UNLISTED = "unlisted"
|
||||
|
||||
|
||||
@cost(*AYRSHARE_POST_COSTS)
|
||||
class PostToYouTubeBlock(Block):
|
||||
"""Block for posting to YouTube with YouTube-specific options."""
|
||||
|
||||
@@ -43,14 +39,6 @@ class PostToYouTubeBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# YouTube is video-only; override the base default so the @cost filter
|
||||
# selects the 5-credit video tier instead of the 2-credit image tier.
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video (always True for YouTube)",
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# YouTube-specific required options
|
||||
title: str = SchemaField(
|
||||
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
|
||||
@@ -149,10 +137,16 @@ class PostToYouTubeBlock(Block):
|
||||
self,
|
||||
input_data: "PostToYouTubeBlock.Input",
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to YouTube with YouTube-specific validation and options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
@@ -308,7 +302,7 @@ class PostToYouTubeBlock(Block):
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
youtube_options=youtube_options,
|
||||
profile_key=credentials.api_key.get_secret_value(),
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
|
||||
@@ -7,7 +7,6 @@ from backend.sdk import BlockCostType, ProviderBuilder
|
||||
# Configure the Meeting BaaS provider with API key authentication
|
||||
baas = (
|
||||
ProviderBuilder("baas")
|
||||
.with_description("Meeting recording and transcription")
|
||||
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
|
||||
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
|
||||
.build()
|
||||
|
||||
@@ -4,34 +4,21 @@ Meeting BaaS bot (recording) blocks.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._api import MeetingBaasAPI
|
||||
from ._config import baas
|
||||
|
||||
# Meeting BaaS recording rate: $0.69 per hour.
|
||||
_MEETING_BAAS_USD_PER_SECOND = 0.69 / 3600
|
||||
|
||||
# Join bills a flat 30 cr commit (covers median short meeting);
|
||||
# FetchMeetingData bills the duration-scaled remainder from the
|
||||
# `duration_seconds` field on the API response. Long meetings no
|
||||
# longer under-bill.
|
||||
|
||||
|
||||
@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=30))
|
||||
class BaasBotJoinMeetingBlock(Block):
|
||||
"""
|
||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
||||
@@ -147,7 +134,6 @@ class BaasBotLeaveMeetingBlock(Block):
|
||||
yield "left", left
|
||||
|
||||
|
||||
@cost(BlockCost(cost_type=BlockCostType.COST_USD, cost_amount=150))
|
||||
class BaasBotFetchMeetingDataBlock(Block):
|
||||
"""
|
||||
Pull MP4 URL, transcript & metadata for a completed meeting.
|
||||
@@ -190,21 +176,9 @@ class BaasBotFetchMeetingDataBlock(Block):
|
||||
include_transcripts=input_data.include_transcripts,
|
||||
)
|
||||
|
||||
bot_meta = data.get("bot_data", {}).get("bot", {}) or {}
|
||||
# Bill recording duration via COST_USD so multi-hour meetings
|
||||
# scale past the Join block's flat 30 cr deposit.
|
||||
duration_seconds = float(bot_meta.get("duration_seconds") or 0)
|
||||
if duration_seconds > 0:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=duration_seconds * _MEETING_BAAS_USD_PER_SECOND,
|
||||
provider_cost_type="cost_usd",
|
||||
)
|
||||
)
|
||||
|
||||
yield "mp4_url", data.get("mp4", "")
|
||||
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
|
||||
yield "metadata", bot_meta
|
||||
yield "metadata", data.get("bot_data", {}).get("bot", {})
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
"""Unit tests for Meeting BaaS duration-based cost emission."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.baas.bots import (
|
||||
_MEETING_BAAS_USD_PER_SECOND,
|
||||
BaasBotFetchMeetingDataBlock,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="baas",
|
||||
title="Mock BaaS API Key",
|
||||
api_key=SecretStr("mock-baas-api-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def test_usd_per_second_derives_from_published_rate():
|
||||
"""$0.69/hour published rate → ~$0.000192/second."""
|
||||
assert _MEETING_BAAS_USD_PER_SECOND == pytest.approx(0.69 / 3600)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"duration_seconds, expected_usd",
|
||||
[
|
||||
(3600, 0.69), # 1 hour
|
||||
(1800, 0.345), # 30 min
|
||||
(0, None), # no recording → no emission
|
||||
(None, None), # missing duration field → no emission
|
||||
],
|
||||
)
|
||||
async def test_fetch_meeting_data_emits_duration_cost_usd(
|
||||
duration_seconds, expected_usd
|
||||
):
|
||||
"""FetchMeetingData extracts duration_seconds from bot metadata and
|
||||
emits provider_cost / cost_usd scaled by the published $0.69/hr rate.
|
||||
Emission is skipped when duration is 0 or missing.
|
||||
"""
|
||||
block = BaasBotFetchMeetingDataBlock()
|
||||
|
||||
bot_meta = {"id": "bot-xyz"}
|
||||
if duration_seconds is not None:
|
||||
bot_meta["duration_seconds"] = duration_seconds
|
||||
|
||||
mock_api = AsyncMock()
|
||||
mock_api.get_meeting_data.return_value = {
|
||||
"mp4": "https://example/recording.mp4",
|
||||
"bot_data": {"bot": bot_meta, "transcripts": []},
|
||||
}
|
||||
|
||||
captured: list[NodeExecutionStats] = []
|
||||
with (
|
||||
patch("backend.blocks.baas.bots.MeetingBaasAPI", return_value=mock_api),
|
||||
patch.object(block, "merge_stats", side_effect=captured.append),
|
||||
):
|
||||
outputs = []
|
||||
async for name, val in block.run(
|
||||
block.input_schema(
|
||||
credentials={
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
},
|
||||
bot_id="bot-xyz",
|
||||
include_transcripts=False,
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((name, val))
|
||||
|
||||
# Always yields the 3 outputs regardless of duration.
|
||||
names = [n for n, _ in outputs]
|
||||
assert "mp4_url" in names and "metadata" in names
|
||||
|
||||
if expected_usd is None:
|
||||
assert captured == []
|
||||
else:
|
||||
assert len(captured) == 1
|
||||
assert captured[0].provider_cost == pytest.approx(expected_usd)
|
||||
assert captured[0].provider_cost_type == "cost_usd"
|
||||
@@ -2,8 +2,7 @@ from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
bannerbear = (
|
||||
ProviderBuilder("bannerbear")
|
||||
.with_description("Auto-generate images and videos")
|
||||
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
|
||||
.with_base_cost(3, BlockCostType.RUN)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -433,7 +433,7 @@ class TestJinaEmbeddingBlockCostTracking:
|
||||
class TestUnrealTextToSpeechBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_character_count(self):
|
||||
"""provider_cost = len(text) * $0.000016 with type='cost_usd'."""
|
||||
"""provider_cost equals len(text) with type='characters'."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
@@ -461,12 +461,12 @@ class TestUnrealTextToSpeechBlockCostTracking:
|
||||
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == pytest.approx(len(test_text) * 0.000016)
|
||||
assert stats.provider_cost_type == "cost_usd"
|
||||
assert stats.provider_cost == float(len(test_text))
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_text_gives_zero_characters(self):
|
||||
"""An empty text string results in provider_cost=0.0 (cost_usd)."""
|
||||
"""An empty text string results in provider_cost=0.0."""
|
||||
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
|
||||
from backend.blocks.text_to_speech_block import (
|
||||
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
|
||||
@@ -494,7 +494,7 @@ class TestUnrealTextToSpeechBlockCostTracking:
|
||||
mock_merge.assert_called_once()
|
||||
stats = mock_merge.call_args[0][0]
|
||||
assert stats.provider_cost == 0.0
|
||||
assert stats.provider_cost_type == "cost_usd"
|
||||
assert stats.provider_cost_type == "characters"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -17,7 +17,6 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -432,7 +431,6 @@ class ClaudeCodeBlock(Block):
|
||||
# The JSON output contains the result
|
||||
output_data = json.loads(raw_output)
|
||||
response = output_data.get("result", raw_output)
|
||||
self._record_cli_cost(output_data)
|
||||
|
||||
# Build conversation history entry
|
||||
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||
@@ -486,23 +484,6 @@ class ClaudeCodeBlock(Block):
|
||||
escaped = prompt.replace("'", "'\"'\"'")
|
||||
return f"'{escaped}'"
|
||||
|
||||
def _record_cli_cost(self, output_data: dict) -> None:
|
||||
"""Feed Claude Code CLI's `total_cost_usd` to the COST_USD resolver.
|
||||
|
||||
The CLI rolls up Anthropic LLM + internal tool-call spend into
|
||||
``total_cost_usd`` on its JSON response; piping it through
|
||||
``merge_stats`` lets the wallet reflect real spend.
|
||||
"""
|
||||
total_cost_usd = output_data.get("total_cost_usd")
|
||||
if total_cost_usd is None:
|
||||
return
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(total_cost_usd),
|
||||
provider_cost_type="cost_usd",
|
||||
)
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Unit tests for ClaudeCodeBlock COST_USD billing migration.
|
||||
|
||||
Verifies:
|
||||
- Block emits provider_cost / cost_usd when Claude Code CLI returns
|
||||
total_cost_usd.
|
||||
- block_usage_cost resolves the COST_USD entry to the expected ceil(usd *
|
||||
cost_amount) credit charge.
|
||||
- Missing total_cost_usd gracefully produces provider_cost=None (no bill).
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockCostType
|
||||
from backend.blocks.claude_code import ClaudeCodeBlock
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.executor.utils import block_usage_cost
|
||||
|
||||
|
||||
def test_claude_code_registered_as_cost_usd_150():
|
||||
"""Sanity: BLOCK_COSTS holds the COST_USD, 150 cr/$ entry."""
|
||||
entries = BLOCK_COSTS[ClaudeCodeBlock]
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.cost_type == BlockCostType.COST_USD
|
||||
assert entry.cost_amount == 150
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"total_cost_usd, expected_credits",
|
||||
[
|
||||
(0.50, 75), # $0.50 × 150 = 75 cr
|
||||
(1.00, 150), # $1.00 × 150 = 150 cr
|
||||
(0.0134, 3), # ceil(0.0134 × 150) = ceil(2.01) = 3
|
||||
(2.00, 300), # $2 × 150 = 300 cr
|
||||
(0.001, 1), # ceil(0.001 × 150) = ceil(0.15) = 1 — no 0-cr leak on
|
||||
# sub-cent runs
|
||||
],
|
||||
)
|
||||
def test_cost_usd_resolver_applies_150_multiplier(total_cost_usd, expected_credits):
|
||||
"""block_usage_cost with cost_usd stats returns ceil(usd * 150)."""
|
||||
block = ClaudeCodeBlock()
|
||||
# cost_filter requires matching e2b_credentials; supply the ones the
|
||||
# registration uses so _is_cost_filter_match accepts the input.
|
||||
entry = BLOCK_COSTS[ClaudeCodeBlock][0]
|
||||
input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
|
||||
stats = NodeExecutionStats(
|
||||
provider_cost=total_cost_usd,
|
||||
provider_cost_type="cost_usd",
|
||||
)
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=input_data, stats=stats
|
||||
)
|
||||
assert cost == expected_credits
|
||||
assert matching_filter == entry.cost_filter
|
||||
|
||||
|
||||
def test_cost_usd_resolver_returns_zero_when_stats_missing_cost():
|
||||
"""Pre-flight (no stats) or unbilled run (provider_cost None) → 0."""
|
||||
block = ClaudeCodeBlock()
|
||||
entry = BLOCK_COSTS[ClaudeCodeBlock][0]
|
||||
input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
|
||||
# No stats at all → pre-flight path, returns 0.
|
||||
pre_cost, _ = block_usage_cost(block=block, input_data=input_data)
|
||||
assert pre_cost == 0
|
||||
# Stats present but no provider_cost → resolver can't bill.
|
||||
stats = NodeExecutionStats()
|
||||
post_cost, _ = block_usage_cost(block=block, input_data=input_data, stats=stats)
|
||||
assert post_cost == 0
|
||||
|
||||
|
||||
def test_record_cli_cost_emits_provider_cost_when_total_cost_present():
|
||||
"""``_record_cli_cost`` (the helper called from ``execute_claude_code``)
|
||||
must emit a single ``merge_stats`` with provider_cost + cost_usd tag
|
||||
when the CLI JSON payload carries ``total_cost_usd``.
|
||||
"""
|
||||
block = ClaudeCodeBlock()
|
||||
captured: list[NodeExecutionStats] = []
|
||||
with patch.object(block, "merge_stats", side_effect=captured.append):
|
||||
block._record_cli_cost(
|
||||
{
|
||||
"result": "hello from claude",
|
||||
"total_cost_usd": 0.0421,
|
||||
"usage": {"input_tokens": 1234, "output_tokens": 56},
|
||||
}
|
||||
)
|
||||
|
||||
assert len(captured) == 1
|
||||
stats = captured[0]
|
||||
assert stats.provider_cost == pytest.approx(0.0421)
|
||||
assert stats.provider_cost_type == "cost_usd"
|
||||
|
||||
|
||||
def test_record_cli_cost_skips_merge_when_total_cost_absent():
|
||||
"""If the CLI payload lacks ``total_cost_usd`` (legacy / non-JSON
|
||||
output), ``_record_cli_cost`` must not call ``merge_stats`` — otherwise
|
||||
we'd pollute telemetry with a ``cost_usd`` emission that has no real
|
||||
cost attached.
|
||||
"""
|
||||
block = ClaudeCodeBlock()
|
||||
mock = MagicMock()
|
||||
with patch.object(block, "merge_stats", mock):
|
||||
block._record_cli_cost({"result": "hello"})
|
||||
mock.assert_not_called()
|
||||
@@ -151,17 +151,6 @@ class CodeGenerationBlock(Block):
|
||||
)
|
||||
self.execution_stats = NodeExecutionStats()
|
||||
|
||||
# GPT-5.1-Codex published pricing: $1.25 / 1M input, $10 / 1M output.
|
||||
_INPUT_USD_PER_1M = 1.25
|
||||
_OUTPUT_USD_PER_1M = 10.0
|
||||
|
||||
@staticmethod
|
||||
def _compute_token_usd(input_tokens: int, output_tokens: int) -> float:
|
||||
return (
|
||||
input_tokens * CodeGenerationBlock._INPUT_USD_PER_1M
|
||||
+ output_tokens * CodeGenerationBlock._OUTPUT_USD_PER_1M
|
||||
) / 1_000_000
|
||||
|
||||
async def call_codex(
|
||||
self,
|
||||
*,
|
||||
@@ -200,15 +189,13 @@ class CodeGenerationBlock(Block):
|
||||
response_id = response.id or ""
|
||||
|
||||
# Update usage stats
|
||||
input_tokens = response.usage.input_tokens if response.usage else 0
|
||||
output_tokens = response.usage.output_tokens if response.usage else 0
|
||||
self.execution_stats.input_token_count = input_tokens
|
||||
self.execution_stats.output_token_count = output_tokens
|
||||
self.execution_stats.llm_call_count += 1
|
||||
self.execution_stats.provider_cost = self._compute_token_usd(
|
||||
input_tokens, output_tokens
|
||||
self.execution_stats.input_token_count = (
|
||||
response.usage.input_tokens if response.usage else 0
|
||||
)
|
||||
self.execution_stats.provider_cost_type = "cost_usd"
|
||||
self.execution_stats.output_token_count = (
|
||||
response.usage.output_tokens if response.usage else 0
|
||||
)
|
||||
self.execution_stats.llm_call_count += 1
|
||||
|
||||
return CodexCallResult(
|
||||
response=text_output,
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Provider registration for Compass — metadata only (auth lives elsewhere)."""
|
||||
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
compass = (
|
||||
ProviderBuilder("compass")
|
||||
.with_description("Geospatial context for agents")
|
||||
.with_supported_auth_types("api_key")
|
||||
.build()
|
||||
)
|
||||
@@ -1,226 +0,0 @@
|
||||
"""Coverage tests for the cost-leak fixes in this PR.
|
||||
|
||||
Each block's ``run()`` / helper emits provider_cost + cost_usd (or items)
|
||||
via merge_stats so the post-flight resolver bills real provider spend.
|
||||
Tests here drive that emission path directly so a regression on any one
|
||||
block surfaces immediately.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import BlockCostType
|
||||
from backend.blocks.ai_condition import AIConditionBlock
|
||||
from backend.data.block_cost_config import BLOCK_COSTS, LLM_COST
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats
|
||||
|
||||
# -------- AIConditionBlock registration --------
|
||||
|
||||
|
||||
def test_ai_condition_registered_under_llm_cost():
|
||||
"""AIConditionBlock was running wallet-free before this PR; verify it
|
||||
now resolves through the same per-model LLM_COST table as every other
|
||||
LLM block.
|
||||
"""
|
||||
assert BLOCK_COSTS[AIConditionBlock] is LLM_COST
|
||||
|
||||
|
||||
# -------- Pinecone insert ITEMS emission --------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pinecone_insert_emits_items_provider_cost():
|
||||
from backend.blocks.pinecone import PineconeInsertBlock
|
||||
|
||||
block = PineconeInsertBlock()
|
||||
captured: list[NodeExecutionStats] = []
|
||||
|
||||
class _FakeIndex:
|
||||
def upsert(self, **_):
|
||||
return None
|
||||
|
||||
class _FakePinecone:
|
||||
def __init__(self, *_, **__):
|
||||
pass
|
||||
|
||||
def Index(self, _name):
|
||||
return _FakeIndex()
|
||||
|
||||
with (
|
||||
patch("backend.blocks.pinecone.Pinecone", _FakePinecone),
|
||||
patch.object(block, "merge_stats", side_effect=captured.append),
|
||||
):
|
||||
input_data = block.input_schema(
|
||||
credentials={
|
||||
"id": "00000000-0000-0000-0000-000000000000",
|
||||
"provider": "pinecone",
|
||||
"type": "api_key",
|
||||
},
|
||||
index="my-index",
|
||||
chunks=["alpha", "beta", "gamma"],
|
||||
embeddings=[[0.1] * 4, [0.2] * 4, [0.3] * 4],
|
||||
namespace="",
|
||||
metadata={},
|
||||
)
|
||||
|
||||
creds = APIKeyCredentials(
|
||||
id="00000000-0000-0000-0000-000000000000",
|
||||
provider="pinecone",
|
||||
title="mock",
|
||||
api_key=SecretStr("mock-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
outputs = [(n, v) async for n, v in block.run(input_data, credentials=creds)]
|
||||
|
||||
assert any(name == "upsert_response" for name, _ in outputs)
|
||||
assert len(captured) == 1
|
||||
stats = captured[0]
|
||||
assert stats.provider_cost == pytest.approx(3.0)
|
||||
assert stats.provider_cost_type == "items"
|
||||
|
||||
|
||||
# -------- Narration model-aware per-char rate --------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id, expected_rate_per_char",
|
||||
[
|
||||
("eleven_flash_v2_5", 0.000167 * 0.5),
|
||||
("eleven_turbo_v2_5", 0.000167 * 0.5),
|
||||
("eleven_multilingual_v2", 0.000167 * 1.0),
|
||||
("eleven_turbo_v2", 0.000167 * 1.0),
|
||||
],
|
||||
)
|
||||
def test_narration_per_char_rate_scales_with_model(model_id, expected_rate_per_char):
|
||||
"""Drive VideoNarrationBlock._record_script_cost directly so a regression
|
||||
that drops the model-aware branching (e.g. hardcoding 1.0 cr/char for
|
||||
all models) makes this test fail.
|
||||
"""
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
|
||||
block = VideoNarrationBlock()
|
||||
captured: list[NodeExecutionStats] = []
|
||||
with patch.object(block, "merge_stats", side_effect=captured.append):
|
||||
block._record_script_cost("x" * 5000, model_id)
|
||||
|
||||
assert len(captured) == 1
|
||||
stats = captured[0]
|
||||
assert stats.provider_cost == pytest.approx(5000 * expected_rate_per_char)
|
||||
assert stats.provider_cost_type == "cost_usd"
|
||||
|
||||
|
||||
# -------- Perplexity None-guard on x-total-cost --------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"openrouter_cost, expect_type",
|
||||
[
|
||||
(0.0421, "cost_usd"), # concrete positive USD → tagged
|
||||
(None, None), # header missing → no tag (keeps gap observable)
|
||||
(0.0, None), # zero → no tag (wouldn't bill anything anyway)
|
||||
],
|
||||
)
|
||||
def test_perplexity_record_openrouter_cost_tags_only_on_concrete_value(
|
||||
openrouter_cost, expect_type
|
||||
):
|
||||
"""Drive PerplexityBlock._record_openrouter_cost directly to verify the
|
||||
None/0 guard. A regression that tags cost_usd unconditionally would
|
||||
silently floor the user's bill to 0 via the resolver — this test
|
||||
would catch it.
|
||||
"""
|
||||
from backend.blocks.perplexity import PerplexityBlock
|
||||
|
||||
block = PerplexityBlock()
|
||||
with patch(
|
||||
"backend.blocks.perplexity.extract_openrouter_cost",
|
||||
return_value=openrouter_cost,
|
||||
):
|
||||
block._record_openrouter_cost(response=object())
|
||||
|
||||
assert block.execution_stats.provider_cost == openrouter_cost
|
||||
assert block.execution_stats.provider_cost_type == expect_type
|
||||
|
||||
|
||||
# -------- Codex COST_USD registration --------
|
||||
|
||||
|
||||
def test_codex_registered_as_cost_usd_150():
|
||||
from backend.blocks.codex import CodeGenerationBlock
|
||||
|
||||
entries = BLOCK_COSTS[CodeGenerationBlock]
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.cost_type == BlockCostType.COST_USD
|
||||
assert entry.cost_amount == 150
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_tokens, output_tokens, expected_usd",
|
||||
[
|
||||
# GPT-5.1-Codex: $1.25 / 1M input, $10 / 1M output.
|
||||
(1_000_000, 0, 1.25),
|
||||
(0, 1_000_000, 10.0),
|
||||
(100_000, 10_000, 0.225), # 0.125 + 0.100
|
||||
(0, 0, 0.0),
|
||||
],
|
||||
)
|
||||
def test_codex_computes_provider_cost_usd_from_token_counts(
|
||||
input_tokens, output_tokens, expected_usd
|
||||
):
|
||||
"""Drive CodeGenerationBlock._compute_token_usd directly. A regression
|
||||
to the wrong rate constants (e.g. swapping the $1.25 input rate for
|
||||
GPT-4o's $2.50) would fail this test.
|
||||
"""
|
||||
from backend.blocks.codex import CodeGenerationBlock
|
||||
|
||||
assert CodeGenerationBlock._compute_token_usd(
|
||||
input_tokens, output_tokens
|
||||
) == pytest.approx(expected_usd)
|
||||
|
||||
|
||||
# -------- ClaudeCode COST_USD registration sanity (already tested in claude_code_cost_test.py) --------
|
||||
|
||||
|
||||
# -------- Perplexity COST_USD registration for all 3 tiers --------
|
||||
|
||||
|
||||
def test_perplexity_sonar_all_tiers_registered_as_cost_usd_150():
|
||||
from backend.blocks.perplexity import PerplexityBlock
|
||||
|
||||
entries = BLOCK_COSTS[PerplexityBlock]
|
||||
# 3 tiers (SONAR, SONAR_PRO, SONAR_DEEP_RESEARCH) all COST_USD 150.
|
||||
assert len(entries) == 3
|
||||
for entry in entries:
|
||||
assert entry.cost_type == BlockCostType.COST_USD
|
||||
assert entry.cost_amount == 150
|
||||
|
||||
|
||||
# -------- Narration COST_USD registration --------
|
||||
|
||||
|
||||
def test_narration_registered_as_cost_usd_150():
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
|
||||
entries = BLOCK_COSTS[VideoNarrationBlock]
|
||||
assert len(entries) == 1
|
||||
assert entries[0].cost_type == BlockCostType.COST_USD
|
||||
assert entries[0].cost_amount == 150
|
||||
|
||||
|
||||
# -------- Pinecone registrations --------
|
||||
|
||||
|
||||
def test_pinecone_registrations():
|
||||
from backend.blocks.pinecone import (
|
||||
PineconeInitBlock,
|
||||
PineconeInsertBlock,
|
||||
PineconeQueryBlock,
|
||||
)
|
||||
|
||||
assert BLOCK_COSTS[PineconeInitBlock][0].cost_type == BlockCostType.RUN
|
||||
assert BLOCK_COSTS[PineconeQueryBlock][0].cost_type == BlockCostType.RUN
|
||||
# Insert scales with item count.
|
||||
assert BLOCK_COSTS[PineconeInsertBlock][0].cost_type == BlockCostType.ITEMS
|
||||
assert BLOCK_COSTS[PineconeInsertBlock][0].cost_amount == 1
|
||||
@@ -19,10 +19,6 @@ class DataForSeoClient:
|
||||
trusted_origins=["https://api.dataforseo.com"],
|
||||
raise_for_status=False,
|
||||
)
|
||||
# USD cost reported by DataForSEO on the most recent successful call.
|
||||
# Populated by keyword_suggestions / related_keywords so the caller
|
||||
# can surface it via NodeExecutionStats.provider_cost for billing.
|
||||
self.last_cost_usd: float = 0.0
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Generate the authorization header using Basic Auth."""
|
||||
@@ -101,9 +97,6 @@ class DataForSeoClient:
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
# DataForSEO reports per-task USD cost; stash it so callers
|
||||
# can populate NodeExecutionStats.provider_cost.
|
||||
self.last_cost_usd = float(task.get("cost") or 0.0)
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
@@ -181,9 +174,6 @@ class DataForSeoClient:
|
||||
if data.get("tasks") and len(data["tasks"]) > 0:
|
||||
task = data["tasks"][0]
|
||||
if task.get("status_code") == 20000: # Success code
|
||||
# DataForSEO reports per-task USD cost; stash it so callers
|
||||
# can populate NodeExecutionStats.provider_cost.
|
||||
self.last_cost_usd = float(task.get("cost") or 0.0)
|
||||
return task.get("result", [])
|
||||
else:
|
||||
error_msg = task.get("status_message", "Task failed")
|
||||
|
||||
@@ -7,17 +7,11 @@ from backend.sdk import BlockCostType, ProviderBuilder
|
||||
# Build the DataForSEO provider with username/password authentication
|
||||
dataforseo = (
|
||||
ProviderBuilder("dataforseo")
|
||||
.with_description("SEO and SERP data")
|
||||
.with_user_password(
|
||||
username_env_var="DATAFORSEO_USERNAME",
|
||||
password_env_var="DATAFORSEO_PASSWORD",
|
||||
title="DataForSEO Credentials",
|
||||
)
|
||||
# DataForSEO reports USD cost per task (e.g. $0.001/keyword returned).
|
||||
# DataForSeoClient stashes it on last_cost_usd; each block emits it via
|
||||
# merge_stats so the COST_USD resolver bills against real spend.
|
||||
# 1000 platform credits per USD → 1 credit per $0.001 (≈ 1 credit/
|
||||
# returned keyword on the standard tier).
|
||||
.with_base_cost(1000, BlockCostType.COST_USD)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ DataForSEO Google Keyword Suggestions block.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -111,10 +110,8 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
test_output=[
|
||||
(
|
||||
"suggestion",
|
||||
lambda x: (
|
||||
hasattr(x, "keyword")
|
||||
and x.keyword == "digital marketing strategy"
|
||||
),
|
||||
lambda x: hasattr(x, "keyword")
|
||||
and x.keyword == "digital marketing strategy",
|
||||
),
|
||||
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
|
||||
("total_count", 1),
|
||||
@@ -170,16 +167,6 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
|
||||
# DataForSEO reports per-task USD cost on the response. Feed it
|
||||
# into NodeExecutionStats so the COST_USD resolver bills the
|
||||
# real provider spend at reconciliation time.
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=client.last_cost_usd,
|
||||
provider_cost_type="cost_usd",
|
||||
)
|
||||
)
|
||||
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
|
||||
@@ -4,7 +4,6 @@ DataForSEO Google Related Keywords block.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -178,16 +177,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
|
||||
# DataForSEO reports per-task USD cost on the response. Feed it
|
||||
# into NodeExecutionStats so the COST_USD resolver bills the
|
||||
# real provider spend at reconciliation time.
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=client.last_cost_usd,
|
||||
provider_cost_type="cost_usd",
|
||||
)
|
||||
)
|
||||
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Provider registration for Discord — metadata only (auth lives in ``_auth.py``)."""
|
||||
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
discord = (
|
||||
ProviderBuilder("discord")
|
||||
.with_description("Messages, channels, and servers")
|
||||
.with_supported_auth_types("api_key", "oauth2")
|
||||
.build()
|
||||
)
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Provider registration for ElevenLabs — metadata only (auth lives in ``_auth.py``)."""
|
||||
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
elevenlabs = (
|
||||
ProviderBuilder("elevenlabs")
|
||||
.with_description("Realistic AI voice synthesis")
|
||||
.with_supported_auth_types("api_key")
|
||||
.build()
|
||||
)
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Provider registration for Enrichlayer — metadata only (auth lives in ``_auth.py``)."""
|
||||
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
enrichlayer = (
|
||||
ProviderBuilder("enrichlayer")
|
||||
.with_description("Enrich leads with company data")
|
||||
.with_supported_auth_types("api_key")
|
||||
.build()
|
||||
)
|
||||
@@ -9,14 +9,8 @@ from ._webhook import ExaWebhookManager
|
||||
# Configure the Exa provider once for all blocks
|
||||
exa = (
|
||||
ProviderBuilder("exa")
|
||||
.with_description("Neural web search")
|
||||
.with_api_key("EXA_API_KEY", "Exa API Key")
|
||||
.with_webhook_manager(ExaWebhookManager)
|
||||
# Exa returns `cost_dollars.total` on every response and ExaSearchBlock
|
||||
# (plus ~45 sibling blocks that share this provider config) already
|
||||
# populates NodeExecutionStats.provider_cost with it. Bill 100 credits
|
||||
# per USD (~$0.01/credit): cheap searches stay at 1–2 credits, a Deep
|
||||
# Research run at $0.20 lands at 20 credits, matching provider spend.
|
||||
.with_base_cost(100, BlockCostType.COST_USD)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import merge_exa_cost
|
||||
|
||||
|
||||
class AnswerCitation(BaseModel):
|
||||
@@ -112,7 +111,3 @@ class ExaAnswerBlock(Block):
|
||||
yield "citations", citations
|
||||
for citation in citations:
|
||||
yield "citation", citation
|
||||
|
||||
# Current SDK AnswerResponse dataclass omits cost_dollars; helper
|
||||
# no-ops today, but keeps billing wired when exa_py adds the field.
|
||||
merge_exa_cost(self, response)
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -22,7 +23,6 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import merge_exa_cost
|
||||
|
||||
|
||||
class CodeContextResponse(BaseModel):
|
||||
@@ -118,5 +118,9 @@ class ExaCodeContextBlock(Block):
|
||||
yield "search_time", context.search_time
|
||||
yield "output_tokens", context.output_tokens
|
||||
|
||||
# API returns costDollars as a bare numeric string like "0.005".
|
||||
merge_exa_cost(self, data)
|
||||
# Parse cost_dollars (API returns as string, e.g. "0.005")
|
||||
try:
|
||||
cost_usd = float(context.cost_dollars)
|
||||
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -23,7 +24,6 @@ from .helpers import (
|
||||
HighlightSettings,
|
||||
LivecrawlTypes,
|
||||
SummarySettings,
|
||||
merge_exa_cost,
|
||||
)
|
||||
|
||||
|
||||
@@ -224,4 +224,6 @@ class ExaContentsBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
merge_exa_cost(self, response)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -143,9 +143,7 @@ class TestExaContentsCostTracking:
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
|
||||
), # type: ignore[arg-type]
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
@@ -174,9 +172,7 @@ class TestExaContentsCostTracking:
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
|
||||
), # type: ignore[arg-type]
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
@@ -205,9 +201,7 @@ class TestExaContentsCostTracking:
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
|
||||
), # type: ignore[arg-type]
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
@@ -303,9 +297,7 @@ class TestExaSimilarCostTracking:
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
|
||||
), # type: ignore[arg-type]
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
@@ -334,9 +326,7 @@ class TestExaSimilarCostTracking:
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
|
||||
), # type: ignore[arg-type]
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import BaseModel, Block, MediaFileType, SchemaField
|
||||
from backend.sdk import BaseModel, MediaFileType, SchemaField
|
||||
|
||||
|
||||
class LivecrawlTypes(str, Enum):
|
||||
@@ -320,7 +319,7 @@ class CostDollars(BaseModel):
|
||||
|
||||
# Helper functions for payload processing
|
||||
def process_text_field(
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, Any]]]:
|
||||
"""Process text field for API payload."""
|
||||
if text is None:
|
||||
@@ -401,7 +400,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,
|
||||
|
||||
|
||||
def process_context_field(
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, int]]]:
|
||||
"""Process context field for API payload."""
|
||||
if context is None:
|
||||
@@ -449,65 +448,3 @@ def add_optional_fields(
|
||||
payload[api_field] = value.value
|
||||
else:
|
||||
payload[api_field] = value
|
||||
|
||||
|
||||
def extract_exa_cost_usd(response: Any) -> Optional[float]:
|
||||
"""Return ``cost_dollars.total`` (USD) from an Exa SDK response, or None.
|
||||
|
||||
Handles dataclass/pydantic responses (``response.cost_dollars.total``),
|
||||
dicts with camelCase keys (``response["costDollars"]["total"]``), dicts
|
||||
with snake_case keys, and bare numeric strings. Returns None whenever the
|
||||
shape is missing cost info — the caller then skips merge_stats.
|
||||
"""
|
||||
if response is None:
|
||||
return None
|
||||
|
||||
# Dataclass / pydantic: response.cost_dollars
|
||||
cost_obj = getattr(response, "cost_dollars", None)
|
||||
|
||||
# Dict payloads: try both camelCase and snake_case
|
||||
if cost_obj is None and isinstance(response, dict):
|
||||
cost_obj = response.get("costDollars") or response.get("cost_dollars")
|
||||
|
||||
if cost_obj is None:
|
||||
return None
|
||||
|
||||
# Already a scalar (code_context endpoint returns a string)
|
||||
if isinstance(cost_obj, (int, float)):
|
||||
return max(0.0, float(cost_obj))
|
||||
if isinstance(cost_obj, str):
|
||||
try:
|
||||
return max(0.0, float(cost_obj))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Nested object/dict: grab the `total` field
|
||||
total = getattr(cost_obj, "total", None)
|
||||
if total is None and isinstance(cost_obj, dict):
|
||||
total = cost_obj.get("total")
|
||||
|
||||
if total is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return max(0.0, float(total))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def merge_exa_cost(block: Block, response: Any) -> None:
|
||||
"""Pull ``cost_dollars.total`` off an Exa response and merge it into stats.
|
||||
|
||||
No-op when the response shape has no cost info (e.g. webset CRUD where
|
||||
the SDK does not expose per-call pricing) — emission happens only when
|
||||
Exa actually reports a USD amount.
|
||||
"""
|
||||
cost_usd = extract_exa_cost_usd(response)
|
||||
if cost_usd is None:
|
||||
return
|
||||
block.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=cost_usd,
|
||||
provider_cost_type="cost_usd",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
"""Unit tests for exa/helpers cost-extraction + merge helpers."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.exa.helpers import extract_exa_cost_usd, merge_exa_cost
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response, expected",
|
||||
[
|
||||
# Dataclass / SimpleNamespace with cost_dollars.total
|
||||
(SimpleNamespace(cost_dollars=SimpleNamespace(total=0.05)), 0.05),
|
||||
# Dict camelCase
|
||||
({"costDollars": {"total": 0.10}}, 0.10),
|
||||
# Dict snake_case
|
||||
({"cost_dollars": {"total": 0.07}}, 0.07),
|
||||
# code_context endpoint shape: plain numeric string
|
||||
(SimpleNamespace(cost_dollars="0.005"), 0.005),
|
||||
# Scalar float on cost_dollars directly
|
||||
(SimpleNamespace(cost_dollars=0.02), 0.02),
|
||||
# Scalar int on cost_dollars
|
||||
(SimpleNamespace(cost_dollars=3), 3.0),
|
||||
# Missing cost info — returns None
|
||||
({}, None),
|
||||
(SimpleNamespace(other="foo"), None),
|
||||
(None, None),
|
||||
# Nested total=None
|
||||
(SimpleNamespace(cost_dollars=SimpleNamespace(total=None)), None),
|
||||
# Invalid numeric string
|
||||
(SimpleNamespace(cost_dollars="not-a-number"), None),
|
||||
# Negative values clamp to 0
|
||||
(SimpleNamespace(cost_dollars=SimpleNamespace(total=-1.0)), 0.0),
|
||||
],
|
||||
)
|
||||
def test_extract_exa_cost_usd_handles_all_shapes(response, expected):
|
||||
assert extract_exa_cost_usd(response) == expected
|
||||
|
||||
|
||||
def test_merge_exa_cost_emits_stats_when_cost_present():
|
||||
block = MagicMock()
|
||||
response = SimpleNamespace(cost_dollars=SimpleNamespace(total=0.0421))
|
||||
merge_exa_cost(block, response)
|
||||
|
||||
block.merge_stats.assert_called_once()
|
||||
stats: NodeExecutionStats = block.merge_stats.call_args.args[0]
|
||||
assert stats.provider_cost == pytest.approx(0.0421)
|
||||
assert stats.provider_cost_type == "cost_usd"
|
||||
|
||||
|
||||
def test_merge_exa_cost_noops_when_no_cost():
|
||||
"""Webset CRUD endpoints don't surface cost_dollars today — the helper
|
||||
must silently skip instead of emitting a 0-cost telemetry record."""
|
||||
block = MagicMock()
|
||||
merge_exa_cost(block, SimpleNamespace(other_field="nothing"))
|
||||
block.merge_stats.assert_not_called()
|
||||
|
||||
|
||||
def test_merge_exa_cost_noops_when_response_is_none():
|
||||
block = MagicMock()
|
||||
merge_exa_cost(block, None)
|
||||
block.merge_stats.assert_not_called()
|
||||
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -25,7 +26,6 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import merge_exa_cost
|
||||
|
||||
|
||||
class ResearchModel(str, Enum):
|
||||
@@ -233,7 +233,11 @@ class ExaCreateResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
merge_exa_cost(self, research)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=research.cost_dollars.total
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await asyncio.sleep(check_interval)
|
||||
@@ -348,7 +352,9 @@ class ExaGetResearchBlock(Block):
|
||||
yield "cost_searches", research.cost_dollars.num_searches
|
||||
yield "cost_pages", research.cost_dollars.num_pages
|
||||
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
|
||||
merge_exa_cost(self, research)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
yield "error_message", research.error
|
||||
|
||||
@@ -435,7 +441,9 @@ class ExaWaitForResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
merge_exa_cost(self, research)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -20,7 +21,6 @@ from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
merge_exa_cost,
|
||||
process_contents_settings,
|
||||
)
|
||||
|
||||
@@ -207,4 +207,6 @@ class ExaSearchBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
merge_exa_cost(self, response)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -19,7 +20,6 @@ from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
merge_exa_cost,
|
||||
process_contents_settings,
|
||||
)
|
||||
|
||||
@@ -168,4 +168,6 @@ class ExaFindSimilarBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
merge_exa_cost(self, response)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -39,7 +39,6 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import merge_exa_cost
|
||||
|
||||
|
||||
class SearchEntityType(str, Enum):
|
||||
@@ -395,7 +394,6 @@ class ExaCreateWebsetBlock(Block):
|
||||
metadata=input_data.metadata,
|
||||
)
|
||||
)
|
||||
merge_exa_cost(self, webset)
|
||||
|
||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
|
||||
@@ -406,7 +404,6 @@ class ExaCreateWebsetBlock(Block):
|
||||
timeout=input_data.polling_timeout,
|
||||
poll_interval=5,
|
||||
)
|
||||
merge_exa_cost(self, final_webset)
|
||||
completion_time = time.time() - start_time
|
||||
|
||||
item_count = 0
|
||||
@@ -482,7 +479,6 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
|
||||
try:
|
||||
webset = await aexa.websets.get(id=input_data.external_id)
|
||||
merge_exa_cost(self, webset)
|
||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
|
||||
yield "webset", webset_result
|
||||
@@ -505,7 +501,6 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
metadata=input_data.metadata,
|
||||
)
|
||||
)
|
||||
merge_exa_cost(self, webset)
|
||||
|
||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
|
||||
@@ -560,7 +555,6 @@ class ExaUpdateWebsetBlock(Block):
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||
merge_exa_cost(self, sdk_webset)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -572,9 +566,8 @@ class ExaUpdateWebsetBlock(Block):
|
||||
yield "status", status_str
|
||||
yield "external_id", sdk_webset.external_id
|
||||
yield "metadata", sdk_webset.metadata or {}
|
||||
yield (
|
||||
"updated_at",
|
||||
(sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
|
||||
yield "updated_at", (
|
||||
sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
|
||||
)
|
||||
|
||||
|
||||
@@ -628,7 +621,6 @@ class ExaListWebsetsBlock(Block):
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
merge_exa_cost(self, response)
|
||||
|
||||
websets_data = [
|
||||
w.model_dump(by_alias=True, exclude_none=True) for w in response.data
|
||||
@@ -687,7 +679,6 @@ class ExaGetWebsetBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
merge_exa_cost(self, sdk_webset)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -715,13 +706,11 @@ class ExaGetWebsetBlock(Block):
|
||||
yield "enrichments", enrichments_data
|
||||
yield "monitors", monitors_data
|
||||
yield "metadata", sdk_webset.metadata or {}
|
||||
yield (
|
||||
"created_at",
|
||||
(sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""),
|
||||
yield "created_at", (
|
||||
sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""
|
||||
)
|
||||
yield (
|
||||
"updated_at",
|
||||
(sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
|
||||
yield "updated_at", (
|
||||
sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
|
||||
)
|
||||
|
||||
|
||||
@@ -760,7 +749,6 @@ class ExaDeleteWebsetBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
||||
merge_exa_cost(self, deleted_webset)
|
||||
|
||||
status_str = (
|
||||
deleted_webset.status.value
|
||||
@@ -811,7 +799,6 @@ class ExaCancelWebsetBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
||||
merge_exa_cost(self, canceled_webset)
|
||||
|
||||
status_str = (
|
||||
canceled_webset.status.value
|
||||
@@ -982,7 +969,6 @@ class ExaPreviewWebsetBlock(Block):
|
||||
payload["entity"] = entity
|
||||
|
||||
sdk_preview = await aexa.websets.preview(params=payload)
|
||||
merge_exa_cost(self, sdk_preview)
|
||||
|
||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||
|
||||
@@ -1066,7 +1052,6 @@ class ExaWebsetStatusBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
merge_exa_cost(self, webset)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
@@ -1201,7 +1186,6 @@ class ExaWebsetSummaryBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
merge_exa_cost(self, webset)
|
||||
|
||||
# Extract basic info
|
||||
webset_id = webset.id
|
||||
@@ -1230,7 +1214,6 @@ class ExaWebsetSummaryBlock(Block):
|
||||
items_response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
merge_exa_cost(self, items_response)
|
||||
sample_items_data = [
|
||||
item.model_dump(by_alias=True, exclude_none=True)
|
||||
for item in items_response.data
|
||||
@@ -1380,7 +1363,6 @@ class ExaWebsetReadyCheckBlock(Block):
|
||||
|
||||
# Get webset details
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
merge_exa_cost(self, webset)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
|
||||
@@ -25,7 +25,6 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import merge_exa_cost
|
||||
|
||||
|
||||
# Mirrored model for stability
|
||||
@@ -206,7 +205,6 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
merge_exa_cost(self, sdk_enrichment)
|
||||
|
||||
enrichment_id = sdk_enrichment.id
|
||||
status = (
|
||||
@@ -228,7 +226,6 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
current_enrich = await aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=enrichment_id
|
||||
)
|
||||
merge_exa_cost(self, current_enrich)
|
||||
current_status = (
|
||||
current_enrich.status.value
|
||||
if hasattr(current_enrich.status, "value")
|
||||
@@ -238,7 +235,6 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
if current_status in ["completed", "failed", "cancelled"]:
|
||||
# Estimate items from webset searches
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
merge_exa_cost(self, webset)
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
@@ -336,7 +332,6 @@ class ExaGetEnrichmentBlock(Block):
|
||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
merge_exa_cost(self, sdk_enrichment)
|
||||
|
||||
enrichment = WebsetEnrichmentModel.from_sdk(sdk_enrichment)
|
||||
|
||||
@@ -430,7 +425,6 @@ class ExaUpdateEnrichmentBlock(Block):
|
||||
try:
|
||||
response = await Requests().patch(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# PATCH /websets/{id}/enrichments/{id} doesn't return costDollars.
|
||||
|
||||
yield "enrichment_id", data.get("id", "")
|
||||
yield "status", data.get("status", "")
|
||||
@@ -483,7 +477,6 @@ class ExaDeleteEnrichmentBlock(Block):
|
||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
merge_exa_cost(self, deleted_enrichment)
|
||||
|
||||
yield "enrichment_id", deleted_enrichment.id
|
||||
yield "success", "true"
|
||||
@@ -535,14 +528,12 @@ class ExaCancelEnrichmentBlock(Block):
|
||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
merge_exa_cost(self, canceled_enrichment)
|
||||
|
||||
# Try to estimate how many items were enriched before cancellation
|
||||
items_enriched = 0
|
||||
items_response = await aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=100
|
||||
)
|
||||
merge_exa_cost(self, items_response)
|
||||
|
||||
for sdk_item in items_response.data:
|
||||
# Check if this enrichment is present
|
||||
|
||||
@@ -29,7 +29,6 @@ from backend.sdk import (
|
||||
|
||||
from ._config import exa
|
||||
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from .helpers import merge_exa_cost
|
||||
|
||||
|
||||
# Mirrored model for stability - don't use SDK types directly in block outputs
|
||||
@@ -298,7 +297,6 @@ class ExaCreateImportBlock(Block):
|
||||
sdk_import = await aexa.websets.imports.create(
|
||||
params=payload, csv_data=input_data.csv_data
|
||||
)
|
||||
merge_exa_cost(self, sdk_import)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
@@ -363,7 +361,6 @@ class ExaGetImportBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
merge_exa_cost(self, sdk_import)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
@@ -433,7 +430,6 @@ class ExaListImportsBlock(Block):
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
merge_exa_cost(self, response)
|
||||
|
||||
# Convert SDK imports to our stable models
|
||||
imports = [ImportModel.from_sdk(i) for i in response.data]
|
||||
@@ -481,7 +477,6 @@ class ExaDeleteImportBlock(Block):
|
||||
deleted_import = await aexa.websets.imports.delete(
|
||||
import_id=input_data.import_id
|
||||
)
|
||||
merge_exa_cost(self, deleted_import)
|
||||
|
||||
yield "import_id", deleted_import.id
|
||||
yield "success", "true"
|
||||
@@ -604,7 +599,7 @@ class ExaExportWebsetBlock(Block):
|
||||
try:
|
||||
all_items = []
|
||||
|
||||
# list_all paginates internally; cost_dollars is not surfaced per-page
|
||||
# Use SDK's list_all iterator to fetch items
|
||||
item_iterator = aexa.websets.items.list_all(
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user