mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
86 Commits
| SHA1 |
|---|
| 7cc1edc61f |
| 4a1741cc15 |
| c08b9774dc |
| fe3d6fb118 |
| c6d31f8252 |
| 28ae7ebac8 |
| e0f9146d54 |
| c3c2737c42 |
| 37f247c795 |
| ae4a421620 |
| 2879528308 |
| 1974ec6260 |
| 932ecd3a07 |
| 4a567a55a4 |
| 2b28434786 |
| 5d1cdc2bad |
| 3c08b90500 |
| 599f370206 |
| 8786c00f9c |
| 384cbd3ccd |
| 8be9cf70af |
| a723966e0b |
| 5b1d9763ed |
| 10ea46663f |
| 06188a86a6 |
| 2deac2073e |
| 24406dfcec |
| 000ddb007a |
| 408b205515 |
| f8c123a8c3 |
| 34374dfd55 |
| 2cb52e5d19 |
| ab88d03b13 |
| 3aa72b4245 |
| cc1f692fec |
| be61dc4304 |
| 575f75edf4 |
| 0f6eea06c4 |
| 43b38f6989 |
| 10e421cd3e |
| 80bfde1ca6 |
| 81d6e91f37 |
| 39cdc0a5e0 |
| 4242da79f0 |
| cf6d7034fa |
| c56c1e5dd6 |
| 6fcbe95645 |
| 9703da3dfd |
| ebb0d3b95b |
| b98bcf31c8 |
| 4f11867d92 |
| 33a608ec78 |
| e3f6d36759 |
| c1b9ed1f5e |
| 45bc167184 |
| e4f291e54b |
| 6efbc59fd8 |
| 6924cf90a5 |
| 07e5a6a9e4 |
| a098f01bd2 |
| 59273fe6a0 |
| 38c2844b83 |
| 24850e2a3e |
| e17e9f13c4 |
| f238c153a5 |
| 01f1289aac |
| 343222ace1 |
| a8226af725 |
| f06b5293de |
| 70b591d74f |
| b1c043c2d8 |
| fcaebd1bb7 |
| 1c0c7a6b44 |
| 3a01874911 |
| 6d770d9917 |
| 334ec18c31 |
| ea5cfdfa2e |
| d13a85bef7 |
| 60b85640e7 |
| 87e4d42750 |
| 0339d95d12 |
| f410929560 |
| 2bbec09e1a |
| 31b88a6e56 |
| d357956d98 |
| 697ffa81f0 |
@@ -25,6 +25,8 @@ Understand the **Why / What / How** before addressing comments — you need cont

```bash
gh pr view {N} --json body --jq '.body'
```

> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks.

## Fetch comments (all sources)

### 1. Inline review threads — GraphQL (primary source of actionable items)
@@ -109,12 +111,16 @@ Only after this loop completes (all pages fetched, count confirmed) should you b

**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
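That filter can be sketched as a jq expression over a saved GraphQL response (a sketch only; the response shape matches the query this skill uses, and `unresolved_filter` is an illustrative name):

```shell
# Sketch: reduce a saved reviewThreads GraphQL response to the unresolved
# threads, keeping the Relay id and the latest comment body.
# `unresolved_filter` is an illustrative variable name, not part of the flow.
unresolved_filter='[.data.repository.pullRequest.reviewThreads.nodes[]
  | select(.isResolved | not)
  | {id, latest: .comments.nodes[0].body}]'
# usage: jq "$unresolved_filter" threads.json
```

Because the query asks for `comments(last: 1)`, `.comments.nodes[0]` is always the latest comment in the thread.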

> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`).

### 2. Top-level reviews — REST (MUST paginate)

```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```

> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.**

**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review, which is typically posted after several CI runs and sits well beyond the first page.

Two things to extract:
@@ -133,6 +139,8 @@ Two things to extract:

```bash
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```

> **Already REST — unaffected by GraphQL rate limits.**

Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers who aren't the PR author — those are the ones that need a response.
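That scan can be expressed as a jq filter (a sketch; `human_filter` is an illustrative name, and `$author` is assumed to hold the PR author's login, passed in via `jq --arg`):

```shell
# Sketch: keep only non-empty comments from human (non-Bot) users who are
# not the PR author.
human_filter='[.[]
  | select(.user.type != "Bot" and .user.login != $author and ((.body // "") != ""))
  | {user: .user.login, body}]'
# usage:
# gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate \
#   | jq --arg author "$AUTHOR" "$human_filter"
```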

## For each unaddressed comment

@@ -327,18 +335,65 @@ git push

5. Restart the polling loop from the top — new commits reset CI status.

## GitHub rate limits

Three distinct rate limits exist — they have different causes, error shapes, and recovery times:

| Error | HTTP code | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until the `X-RateLimit-Reset` header timestamp |
| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit — distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because of per-query point costs. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works — use the fallbacks below. |

**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.

### Detection

The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID <id>` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check the current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited or fully down):

```bash
gh api rate_limit --jq '.resources.graphql'  # {"limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...}
# Human-readable reset time (BSD date; on GNU/Linux use `date -d @<epoch>`):
gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {}
```

Retry when `remaining > 0`. If you need to proceed sooner, sleep 2–5 min and probe again — the limit is per user, not per machine, so other concurrent agents under the same token also consume it.
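The probe-and-retry loop can be wrapped in a small helper (a sketch; `wait_for_graphql` is a hypothetical name, and the arithmetic assumes `reset` is a Unix epoch, which `gh api rate_limit` returns):

```shell
# Sketch: block until the GraphQL quota has headroom, sleeping until the
# reported reset time (plus a small buffer) when it is exhausted.
wait_for_graphql() {
  while :; do
    remaining=$(gh api rate_limit --jq '.resources.graphql.remaining')
    [ "${remaining:-0}" -gt 0 ] && return 0
    reset=$(gh api rate_limit --jq '.resources.graphql.reset')
    now=$(date -u +%s)
    if [ "${reset:-0}" -gt "$now" ]; then
      sleep $(( reset - now + 5 ))   # wait out the window, +5s buffer
    else
      sleep 60                       # reset time unclear: probe again shortly
    fi
  done
}
```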

### What keeps working

When GraphQL is unavailable (rate-limited or outage):

- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe.
- **Degraded:** the inline thread list — fall back to the flat `/pulls/{N}/comments` REST endpoint, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply.
- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset.

### Fall back to REST

**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object:

```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body'       # == --json body
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref'   # == --json baseRefName
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable'  # == --json mergeable
```

Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again).
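The mapping can be made explicit with a tiny helper (a sketch; `merge_state` is an illustrative name):

```shell
# Sketch: normalize the REST `mergeable` value to the GraphQL vocabulary.
merge_state() {
  case "$1" in
    true)  echo "MERGEABLE" ;;
    false) echo "CONFLICTING" ;;
    *)     echo "UNKNOWN" ;;   # null / anything else: still computing, poll again
  esac
}
```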

**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply:

```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \
  | jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]'
```

Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets.

**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed — use the same command as the main flow.

**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance).
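One way to implement the queue-and-retry flow (a sketch; the queue path and helper names are assumptions, not part of the documented flow):

```shell
# Sketch: append Relay thread IDs to a scratch file while GraphQL is
# exhausted, then replay them as resolve mutations once the quota resets.
QUEUE=/tmp/resolve-queue.txt   # assumed scratch path

queue_thread() { echo "$1" >> "$QUEUE"; }

replay_queue() {
  while read -r tid; do
    gh api graphql -f query="mutation { resolveReviewThread(input: {threadId: \"$tid\"}) { thread { isResolved } } }"
    sleep 3   # per the secondary-limit guidance above
  done < "$QUEUE" && rm -f "$QUEUE"
}
```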

### Recovery from secondary rate limit (403 abuse)

1. Stop all API writes immediately
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
3. Resume with `sleep 3` between each call
@@ -397,6 +452,8 @@ gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREA

**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.

> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow.

### Verify actual count before outputting ORCHESTRATOR:DONE

Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:

.claude/skills/pr-polish/SKILL.md (new file, 275 lines)
@@ -0,0 +1,275 @@

---
name: pr-polish
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
user-invocable: true
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# PR Polish

**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:

1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
3. Every non-bot, non-author top-level review has been acknowledged (replied to) OR resolved via a thread it spawned.
4. Every non-bot, non-author issue comment has been acknowledged (replied to).
5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending.
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late, after CI settles; a single green snapshot is not sufficient.

**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 rounds in practice.

## TodoWrite

Before starting, write two todos so the user can see the loop progression:

- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.

Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).

## Find the PR

```bash
ARG_PR="${ARG:-}"
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
  ARG_PR="${BASH_REMATCH[1]}"
fi
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
  echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
  exit 1
fi
echo "Polishing PR #$PR"
```


## The outer loop

```text
round = 0
while round < _MAX_ROUNDS:
    round += 1
    baseline = snapshot_state(PR)      # see "Snapshotting state" below
    invoke_skill("pr-review", PR)      # posts findings as inline comments / top-level review
    findings = diff_state(PR, baseline)
    if findings.total == 0:
        break                          # no new findings → go to polish polling
    invoke_skill("pr-address", PR)     # resolves every unresolved thread + CI failure

# Post-loop: polish polling (see below).
polish_polling(PR)
```


### Snapshotting state

Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):

```bash
# Inline threads — total count + latest databaseId per thread
gh api graphql -f query="
{
  repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
    pullRequest(number: ${PR}) {
      reviewThreads(first: 100) {
        totalCount
        nodes {
          id
          isResolved
          comments(last: 1) { nodes { databaseId } }
        }
      }
    }
  }
}" > /tmp/baseline_threads.json

# Top-level reviews — count + latest id per non-empty review
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
  --jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
  > /tmp/baseline_reviews.json

# Issue comments — count + latest id per non-bot, non-author comment.
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
# accounts like coderabbitai, github-actions, sentry-io). The author is
# filtered by comparing login to the PR author. Note: gh's built-in --jq
# cannot take jq's --arg, so pipe to jq instead; -s + add merges the
# per-page arrays that --paginate emits.
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate |
  jq -s --arg author "$AUTHOR" \
    'add
     | [.[] | select(.user.type != "Bot" and .user.login != $author)
        | {id, user: .user.login, created_at}]' \
  > /tmp/baseline_issue_comments.json
```


### Diffing after a review

After `/pr-review` runs, any of these counting as "new findings" means another address round is needed:

- A new inline thread `id` not in the baseline.
- An existing thread whose latest comment `databaseId` is higher than the baseline's (a new reply on an old thread).
- A new top-level review `id` with a non-empty body.
- A new issue comment `id` from a non-bot, non-author user.

If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.
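The first bucket (new thread IDs) can be computed with jq from the two snapshots (a sketch; `new_thread_count` and its file arguments are illustrative, and both files are assumed to hold responses saved by the snapshot query above):

```shell
# Sketch: count thread ids in the current snapshot that are absent from
# the baseline snapshot.
new_thread_count() {  # usage: new_thread_count baseline.json current.json
  jq -n --slurpfile base "$1" --slurpfile cur "$2" '
    ($base[0].data.repository.pullRequest.reviewThreads.nodes | map(.id)) as $seen
    | [$cur[0].data.repository.pullRequest.reviewThreads.nodes[]
       | select(.id as $i | $seen | index($i) | not)]
    | length'
}
```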

## Polish polling

Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals:

```text
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
clean_polls = 0
required_clean = 2
while clean_polls < required_clean:
    # 1. CI gate — any terminal non-success conclusion (not just "failure")
    #    must trigger /pr-address. "success", "skipped", "neutral" are clean;
    #    anything else (including cancelled, timed_out, action_required) is a
    #    blocker that won't self-resolve.
    ci = fetch_check_runs(PR)
    if any ci.conclusion in NON_SUCCESS_TERMINAL:
        invoke_skill("pr-address", PR)   # address failures + any new comments
        baseline = snapshot_state(PR)    # reset — push during address invalidates old baseline
        clean_polls = 0
        continue
    if any ci.conclusion is None (still in_progress):
        sleep 60; continue               # wait without counting this as clean

    # 2. Comment / thread gate
    threads = fetch_unresolved_threads(PR)
    new_issue_comments = diff_against_baseline(issue_comments)
    new_reviews = diff_against_baseline(reviews)
    if threads or new_issue_comments or new_reviews:
        invoke_skill("pr-address", PR)
        baseline = snapshot_state(PR)    # reset — the address loop just dealt with these;
                                         # otherwise they stay "new" relative to the old baseline forever
        clean_polls = 0
        continue

    # 3. Mergeability gate
    mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
    if mergeable == false (CONFLICTING):
        resolve_conflicts(PR)            # see pr-address skill
        clean_polls = 0
        continue
    if mergeable is null (UNKNOWN):
        sleep 60; continue

    clean_polls += 1
    sleep 60
```

Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
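The `fetch_unresolved_threads` step can be sketched as a single GraphQL count (first page only here; paginate for >100 threads, per the `pr-address` skill; `count_unresolved` is an illustrative name):

```shell
# Sketch: count unresolved review threads on the PR (first 100 threads).
count_unresolved() {  # usage: count_unresolved <pr-number>
  gh api graphql -f query="
    { repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
        pullRequest(number: $1) {
          reviewThreads(first: 100) { nodes { isResolved } } } } }" \
    --jq '[.data.repository.pullRequest.reviewThreads.nodes[]
           | select(.isResolved | not)] | length'
}
```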

### Concrete CI fetch (don't parse `gh pr checks` text columns)

The `fetch_check_runs(PR)` step above must use `--json`, not the default text output. Job names can contain spaces and parentheses (e.g. `test (3.11)`, `Analyze (python)`), so `gh pr checks $PR | awk '{print $2}'` extracts `(3.11)` instead of the status — leading to a clean poll firing while jobs are still pending.

```bash
# Reliable: use --json so columns are unambiguous.
ci_json=$(gh pr checks $PR --repo Significant-Gravitas/AutoGPT --json name,state,bucket)
pending=$(echo "$ci_json" | jq '[.[] | select(.bucket == "pending")] | length')
failed=$(echo "$ci_json" | jq '[.[] | select(.bucket == "fail" or .bucket == "cancel")] | length')

# Buckets are: pass | fail | pending | cancel | skipping
# (NOTE: gh pr checks does NOT expose `conclusion` as a JSON field —
# only `bucket`. Don't confuse it with the GitHub REST API's check_runs
# endpoint, which DOES use conclusion.)
```

Map back to the pseudocode above: `bucket == "pending"` is `ci.conclusion is None (still in_progress)`; `bucket in {"fail", "cancel"}` is `ci.conclusion in NON_SUCCESS_TERMINAL`; `bucket in {"pass", "skipping"}` is clean.
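That mapping can be folded into one helper returning the poll outcome (a sketch; `classify_ci` is an illustrative name, and the input is the `$ci_json` array from the `gh pr checks --json` call above):

```shell
# Sketch: classify a gh pr checks --json array into the three poll outcomes.
classify_ci() {  # echoes one of: blocker | pending | clean
  local json="$1"
  if jq -e '[.[] | select(.bucket == "fail" or .bucket == "cancel")] | length > 0' \
       <<<"$json" >/dev/null; then
    echo "blocker"     # terminal non-success: invoke /pr-address
  elif jq -e '[.[] | select(.bucket == "pending")] | length > 0' \
       <<<"$json" >/dev/null; then
    echo "pending"     # still running: sleep and re-poll, not a clean poll
  else
    echo "clean"       # pass / skipping only
  fi
}
```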

### Why 2 clean polls, not 1

A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments gives high confidence nothing else is incoming.

### Why check every source each poll

`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:

- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles).
- Issue comments from human reviewers (not caught by inline thread polling).
- Sentry bug predictions that land on new line numbers post-push.
- Merge conflicts introduced by a race between your push and a merge to `dev`.

## Invocation pattern

Delegate to the existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.

```python
Skill(skill="pr-review", args=pr_url)
Skill(skill="pr-address", args=pr_url)
```

After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.

### **Auto-continue: do NOT end your response between child skills**

`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:

- Do NOT summarize and stop.
- Do NOT wait for user confirmation to continue.
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.

The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).

If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.

### **Run /pr-polish in the foreground — never in a background agent**

Spawning `/pr-polish` inside an `Agent(subagent_type="general-purpose")` background task **does not work**. Background agents don't inherit the parent's slash-command registry, so `Skill(skill="pr-review")` and `Skill(skill="pr-address")` calls aren't available — the agent has to manually replicate the child skills' logic, which is fragile and tends to stall on the first network or rate-limit hiccup. Symptom: the background task reports `stalled: no progress for 600s` mid-review.

Run `/pr-polish` inline in the foreground conversation. If the user asks for "/pr-polish + /pr-test in parallel", split them: run `/pr-polish` in the foreground; only the test step may go to a background agent (because `/pr-test` doesn't itself need to invoke skills).

### **You MUST invoke `Skill(pr-review)` every round — even when bot reviews already exist**

A common failure mode: CodeRabbit / autogpt-reviewer / Sentry have already posted findings on the PR, and the orchestrator skips the `Skill(pr-review)` step on the assumption that "review has been done." That's wrong — the outer loop's purpose is to layer **the agent's own review** on top of the bot reviews, catching issues the bots miss (architecture, naming, cross-file invariants, hidden coupling). If the orchestrator only addresses bot findings without ever running its own review, the loop converges to "bot-clean" but not "agent-reviewed-clean", and the user reasonably asks "did /pr-polish even read the diff?"

**Self-check before reporting `ORCHESTRATOR:DONE`:** confirm at least one `Skill(skill="pr-review")` call appears in the current orchestration. If none, the loop is incomplete — go back and run one round.

## GitHub rate limits

This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:

- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.

## Safety valves

- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is a CI `failure` that will never self-resolve.
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post a finding that a prior round already resolved.

## Reporting

When the skill finishes (either via two clean polls or by hitting `_MAX_ROUNDS`), produce a compact summary:

```
PR #{N} polish complete ({rounds_completed} rounds):
- {X} inline threads opened and resolved
- {Y} CI failures fixed
- {Z} new commits pushed
Final state: CI green, {total} threads all resolved, mergeable.
```

If exiting via `_MAX_ROUNDS`, flag it explicitly:

```
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
- {N} threads still unresolved: {titles}
- CI status: {summary}
Needs human review.
```

## When to use this skill

Use when the user says any of:

- "polish this PR"
- "keep reviewing and addressing until it's mergeable"
- "loop /pr-review + /pr-address until done"
- "make sure the PR is actually merge-ready"

Do **not** use when:

- The user wants just one review pass (→ `/pr-review`).
- The user wants to address already-posted comments without further self-review (→ `/pr-address`).
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.
@@ -5,7 +5,7 @@ user-invocable: true

argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
  author: autogpt-team
  version: "2.1.0"
---

# Manual E2E Test
@@ -180,6 +180,120 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:

**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.

## Step 3.0: Claim the testing lock (coordinate parallel agents)

Multiple worktrees share the same host — Docker infra (postgres, redis, clamav), app ports (3000/8006/…), and the test user. Two agents running `/pr-test` concurrently will corrupt each other's state (connection-pool exhaustion, port binds failing silently, cross-test assertions). Use the root-worktree lock file to take turns.

### Lock file contract

Path (**always** the root worktree so all siblings see it): `$REPO_ROOT/.ign.testing.lock`

Body (one `key=value` per line):

```
holder=<pr-XXXXX-purpose>
pid=<pid-or-"self">
started=<iso8601>
heartbeat=<iso8601, updated every ~2 min>
worktree=<full path>
branch=<branch name>
intent=<one-line description + rough duration>
```
|
### Claim
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LOCK=$REPO_ROOT/.ign.testing.lock
|
||||||
|
NOW=$(date -u +%Y-%m-%dT%H:%MZ)
|
||||||
|
STALE_AFTER_MIN=5
|
||||||
|
|
||||||
|
if [ -f "$LOCK" ]; then
|
||||||
|
HB=$(grep '^heartbeat=' "$LOCK" | cut -d= -f2)
|
||||||
|
HB_EPOCH=$(date -j -f '%Y-%m-%dT%H:%MZ' "$HB" +%s 2>/dev/null || date -d "$HB" +%s 2>/dev/null || echo 0)
|
||||||
|
AGE_MIN=$(( ( $(date -u +%s) - HB_EPOCH ) / 60 ))
|
||||||
|
if [ "$AGE_MIN" -gt "$STALE_AFTER_MIN" ]; then
|
||||||
|
echo "WARN: stale lock (${AGE_MIN}m old) — reclaiming"
|
||||||
|
cat "$LOCK" | sed 's/^/ stale: /'
|
||||||
|
else
|
||||||
|
echo "Another agent holds the lock:"; cat "$LOCK"
|
||||||
|
echo "Wait until released or resume after $((STALE_AFTER_MIN - AGE_MIN))m."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat > "$LOCK" <<EOF
|
||||||
|
holder=pr-${PR_NUMBER}-e2e
|
||||||
|
pid=self
|
||||||
|
started=$NOW
|
||||||
|
heartbeat=$NOW
|
||||||
|
worktree=$WORKTREE_PATH
|
||||||
|
branch=$(cd $WORKTREE_PATH && git branch --show-current)
|
||||||
|
intent=E2E test PR #${PR_NUMBER}, native mode, ~60min
|
||||||
|
EOF
|
||||||
|
echo "Lock claimed"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Heartbeat (MUST run in background during the whole test)
|
||||||
|
|
||||||
|
Without a heartbeat a crashed agent keeps the lock forever. Run this as a background process right after claim:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
(while true; do
|
||||||
|
sleep 120
|
||||||
|
[ -f "$LOCK" ] || exit 0 # lock released → exit heartbeat
|
||||||
|
perl -i -pe "s/^heartbeat=.*/heartbeat=$(date -u +%Y-%m-%dT%H:%MZ)/" "$LOCK"
|
||||||
|
done) &
|
||||||
|
HEARTBEAT_PID=$!
|
||||||
|
echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
|
||||||
|
```
|
||||||
|
|
||||||
|
### Release (always — even on failure)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
kill "$HEARTBEAT_PID" 2>/dev/null
|
||||||
|
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
|
||||||
|
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
|
||||||
|
>> $REPO_ROOT/.ign.testing.log
|
||||||
|
```
|
||||||
|
|
||||||
|
Use a `trap` so release runs even on `exit 1`:
|
||||||
|
```bash
|
||||||
|
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
|
||||||
|
```
|
||||||
|
|
||||||
|
### **Release the lock AS SOON AS the test run is done**
|
||||||
|
|
||||||
|
The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if:
|
||||||
|
|
||||||
|
- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually.
|
||||||
|
- You're leaving docker containers up.
|
||||||
|
- You're tailing logs for a minute or two.
|
||||||
|
|
||||||
|
Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone.
|
||||||
|
|
||||||
|
Concretely, the sequence at the end of every `/pr-test` run (success or failure) is:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Write the final report + post PR comment — done above in Step 5/6.
|
||||||
|
# 2. Release the lock right now, even if the app is still up.
|
||||||
|
kill "$HEARTBEAT_PID" 2>/dev/null
|
||||||
|
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
|
||||||
|
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
|
||||||
|
>> $REPO_ROOT/.ign.testing.log
|
||||||
|
# 3. Optionally leave the app running and note it so the user knows:
|
||||||
|
echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
|
||||||
|
echo " pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
|
||||||
|
```
|
||||||
|
|
||||||
|
If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test.
|
||||||
|
|
||||||
|
### Shared status log
|
||||||
|
|
||||||
|
`$REPO_ROOT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
|
||||||
|
```bash
|
||||||
|
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
|
||||||
|
>> $REPO_ROOT/.ign.testing.log
|
||||||
|
```
|
||||||
|
|
||||||
## Step 3: Environment setup
|
## Step 3: Environment setup
|
||||||
|
|
||||||
### 3a. Copy .env files from the root worktree
|
### 3a. Copy .env files from the root worktree
|
||||||
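An aside on the claim step: checking `-f "$LOCK"` and then writing with `cat >` leaves a small window in which two agents can both claim. The check and the write can be fused into one atomic operation; a minimal sketch in Python (illustrative only, not part of the skill; the `claim` helper and the temp-dir stand-in for `$REPO_ROOT` are invented here):

```python
import os
import tempfile

# O_CREAT | O_EXCL makes "create only if absent" a single syscall, so two
# claimants cannot both succeed; the loser gets FileExistsError.
LOCK = os.path.join(tempfile.mkdtemp(), ".ign.testing.lock")

def claim(holder: str) -> bool:
    try:
        fd = os.open(LOCK, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        return False
    with os.fdopen(fd, "w") as f:
        f.write(f"holder={holder}\n")
    return True

first = claim("pr-1234-e2e")   # True: the lock file did not exist yet
second = claim("pr-5678-e2e")  # False: already held
print(first, second)
```

The shell equivalent of `O_EXCL` is `set -C` (noclobber), which makes the `>` redirection itself fail when the file already exists.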
@@ -248,7 +362,87 @@ docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocke
 done
 ```

-### 3e. Build and start
+**Native mode also:** when running the app natively (see 3e-native), kill any stray host processes and free the app ports before starting — otherwise `poetry run app` and `pnpm dev` will fail to bind.
+
+```bash
+# Kill stray native app processes from prior runs
+pkill -9 -f "python.*backend" 2>/dev/null || true
+pkill -9 -f "poetry run app" 2>/dev/null || true
+pkill -9 -f "next-server|next dev" 2>/dev/null || true
+
+# Free app ports (errors per port are ignored — port may simply be unused)
+for port in 3000 8006 8001 8002 8005 8008; do
+  lsof -ti :$port -sTCP:LISTEN | xargs -r kill -9 2>/dev/null || true
+done
+```
+
+### 3e-native. Run the app natively (PREFERRED for iterative dev)
+
+Native mode runs infra (postgres, supabase, redis, rabbitmq, clamav) in docker but runs the backend and frontend directly on the host. This avoids the 3-8 minute `docker compose build` cycle on every backend change — code edits are picked up on process restart (seconds) instead of a full image rebuild.
+
+**When to prefer native mode (default for this skill):**
+- Iterative dev/debug loops where you're editing backend or frontend code between test runs
+- Any PR that touches Python/TS source but not Dockerfiles, compose config, or infra images
+- Fast repro of a failing scenario — restart `poetry run app` in a couple of seconds
+
+**When to prefer docker mode (3e fallback):**
+- Testing changes to `Dockerfile`, `docker-compose.yml`, or base images
+- Production-parity smoke tests (exact container env, networking, volumes)
+- CI-equivalent runs where you need the exact image that'll ship
+
+**Note on 3b (copilot auth):** no npm install anywhere. `poetry install` pulls in `claude_agent_sdk`, which ships its own Claude CLI binary — available on `PATH` whenever you run commands via `poetry run` (native) OR whenever the copilot_executor container is built from its Poetry lockfile (docker). The OAuth token extraction still applies (same `refresh_claude_token.sh` call).
+
+**Preamble:** before starting native, run the kill-stray + free-ports block from 3c's "Native mode also" subsection.
+
+**1. Start infra only (one-time per session):**
+
+```bash
+cd $PLATFORM_DIR && docker compose --profile local up deps --detach --remove-orphans --build
+```
+
+This brings up postgres/supabase/redis/rabbitmq/clamav and skips all app services.
+
+**2. Start the backend natively:**
+
+```bash
+cd $BACKEND_DIR && (poetry run app 2>&1 | tee .ign.application.logs) &
+```
+
+`poetry run app` spawns **all** app subprocesses — `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, `database_manager` — inside ONE parent process. No separate containers, no separate terminals. The `.ign.application.logs` prefix is already gitignored.
+
+**3. Wait for the backend on :8006 BEFORE starting the frontend.** This ordering matters — the frontend's `pnpm dev` startup invokes `generate-api-queries`, which fetches `/openapi.json` from the backend. If the backend isn't listening yet, `pnpm dev` fails immediately.
+
+```bash
+for i in $(seq 1 60); do
+  if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs 2>/dev/null)" = "200" ]; then
+    echo "Backend ready"
+    break
+  fi
+  sleep 2
+done
+```
+
+**4. Start the frontend natively:**
+
+```bash
+cd $FRONTEND_DIR && (pnpm dev 2>&1 | tee .ign.frontend.logs) &
+```
+
+**5. Wait for the frontend on :3000:**
+
+```bash
+for i in $(seq 1 60); do
+  if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:3000 2>/dev/null)" = "200" ]; then
+    echo "Frontend ready"
+    break
+  fi
+  sleep 2
+done
+```
+
+Once both are up, skip 3e/3f and go straight to **3g/3h** (feature flags / test user creation).
+
+### 3e. Build and start (docker — fallback)
+
 ```bash
 cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
@@ -442,6 +636,22 @@ agent-browser --session-name pr-test snapshot | grep "text:"

 ### Checking logs

+**Native mode:** when running via `poetry run app` + `pnpm dev`, all app logs stream to the `.ign.*.logs` files written by the `tee` pipes in 3e-native. `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, and `database_manager` are all subprocesses of the single `poetry run app` parent, so their output is interleaved in `.ign.application.logs`.
+
+```bash
+# Backend (all app subprocesses interleaved)
+tail -f $BACKEND_DIR/.ign.application.logs
+
+# Frontend (Next.js dev server)
+tail -f $FRONTEND_DIR/.ign.frontend.logs
+
+# Filter for errors across either log
+grep -iE "error|exception|traceback" $BACKEND_DIR/.ign.application.logs | tail -20
+grep -iE "error|exception|traceback" $FRONTEND_DIR/.ign.frontend.logs | tail -20
+```
+
+**Docker mode:**
+
 ```bash
 # Backend REST server
 docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
@@ -571,6 +781,19 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations

-**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `` inline in the PR comment so reviewers can see them without clicking any links.
+**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
+
+**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path.
+
+**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`):
+
+```bash
+# Reject any local-looking absolute path or home-dir shortcut in the body
+if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then
+  echo "ABORT: local filesystem paths detected in PR comment body."
+  echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting."
+  exit 1
+fi
+```
+
 ```bash
 # Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
 REPO="Significant-Gravitas/AutoGPT"
@@ -876,9 +1099,15 @@ test scenario → find issue (bug OR UX problem) → screenshot broken state
 ### Problem: Frontend shows cookie banner blocking interaction
 **Fix:** `agent-browser click 'text=Accept All'` before other interactions.

-### Problem: Container loses npm packages after rebuild
-**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
-**Fix:** Add packages to the Dockerfile instead of installing at runtime.
+### Problem: Claude CLI not found in copilot_executor container
+**Symptom:** Copilot logs say `claude: command not found` or similar when starting an SDK turn.
+**Cause:** Image was built without `poetry install` (stale base layer, or Dockerfile bypass). The SDK CLI ships inside the `claude_agent_sdk` Poetry dep — it is NOT an npm package.
+**Fix:** Rebuild the image cleanly: `docker compose build --no-cache copilot_executor && docker compose up -d copilot_executor`. Do NOT `docker exec ... npm install -g @anthropic-ai/claude-code` — that is outdated guidance and will pollute the container with a second CLI that the SDK won't use.
+
+### Problem: agent-browser screenshot hangs / times out
+**Symptom:** `agent-browser screenshot` exits with code 124 even on `about:blank`.
+**Cause:** Stuck CDP connection or Chromium process tree. Seen on macOS when a prior `/pr-test` left a zombie Chrome for Testing.
+**Fix:** `pkill -9 -f "agent-browser|chromium|Chrome for Testing" && sleep 2`, then reopen the browser with a fresh `--session-name`. If still failing, verify via `agent-browser eval` + `agent-browser snapshot` (DOM state) instead of relying on PNGs — the feature under test is the same.

 ### Problem: Services not starting after `docker compose up`
 **Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
79 .github/workflows/platform-backend-ci.yml vendored
@@ -119,10 +119,12 @@ jobs:
     runs-on: ubuntu-latest

     services:
-      redis:
-        image: redis:latest
-        ports:
-          - 6379:6379
+      # Redis is provisioned as a real 3-shard cluster below via docker
+      # run (see the "Start Redis Cluster" step). GHA services can't
+      # override the image CMD or stand up multi-container clusters, so
+      # that setup is inlined — it mirrors the topology of the local dev
+      # compose stack (autogpt_platform/docker-compose.platform.yml) and
+      # prod helm chart.
       rabbitmq:
         image: rabbitmq:4.1.4
         ports:
@@ -166,6 +168,68 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}

+      - name: Start Redis Cluster (3 shards)
+        run: |
+          # 3-master Redis Cluster matching the local compose stack
+          # (autogpt_platform/docker-compose.platform.yml) and prod. Each
+          # shard runs in its own container on a dedicated bridge network,
+          # announces its compose-style hostname for intra-network clients,
+          # and publishes 1700N on the GHA host so tests can reach every
+          # shard via localhost. The backend's ``_address_remap`` rewrites
+          # every CLUSTER SLOTS reply to localhost:<announced-port>, which
+          # picks the right published port per shard.
+          #
+          # Not reusing docker-compose.platform.yml directly because compose
+          # validates the full file even when only some services are ``up``,
+          # and that file references services (db/kong/...) defined in a
+          # sibling compose file — pulling both in would needlessly couple
+          # CI to the full local-dev stack.
+          docker network create redis-cluster-ci
+          for i in 0 1 2; do
+            port=$((17000 + i))
+            bus=$((27000 + i))
+            docker run -d --name redis-$i --network redis-cluster-ci \
+              --network-alias redis-$i \
+              -p $port:$port \
+              redis:7 \
+              redis-server --port $port \
+                --cluster-enabled yes \
+                --cluster-config-file nodes.conf \
+                --cluster-node-timeout 5000 \
+                --cluster-require-full-coverage no \
+                --cluster-announce-hostname redis-$i \
+                --cluster-announce-port $port \
+                --cluster-announce-bus-port $bus \
+                --cluster-preferred-endpoint-type hostname
+          done
+          # Wait for each shard to accept commands.
+          for i in 0 1 2; do
+            port=$((17000 + i))
+            for _ in $(seq 1 30); do
+              docker exec redis-$i redis-cli -p $port ping 2>/dev/null | grep -q PONG && break
+              sleep 1
+            done
+          done
+          # Form the cluster from an init container on the same network so
+          # --cluster-preferred-endpoint-type hostname resolves redis-0/1/2.
+          docker run --rm --network redis-cluster-ci redis:7 \
+            redis-cli --cluster create \
+              redis-0:17000 redis-1:17001 redis-2:17002 \
+              --cluster-replicas 0 --cluster-yes
+          # Confirm convergence.
+          for _ in $(seq 1 30); do
+            state=$(docker exec redis-0 redis-cli -p 17000 cluster info | awk -F: '/^cluster_state:/ {print $2}' | tr -d '[:cntrl:]')
+            if [ "$state" = "ok" ]; then
+              echo "Redis Cluster ready (3 shards, state=ok)"
+              docker exec redis-0 redis-cli -p 17000 cluster nodes
+              exit 0
+            fi
+            sleep 1
+          done
+          echo "Redis Cluster failed to reach ok state" >&2
+          docker exec redis-0 redis-cli -p 17000 cluster info >&2 || true
+          exit 1
+
       - name: Setup Supabase
         uses: supabase/setup-cli@v1
         with:
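An aside on the cluster topology above: which of the three shards serves a given key is decided by Redis Cluster's slot mapping, CRC16 (XMODEM variant) of the key modulo 16384, with `{hash tag}` extraction so related keys can be co-located. A dependency-free sketch of that mapping (not code from this PR):

```python
def crc16_xmodem(data: bytes) -> int:
    """CRC-16/XMODEM (poly 0x1021, init 0), the checksum Redis Cluster uses."""
    crc = 0
    for byte in data:
        crc ^= byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ 0x1021) if crc & 0x8000 else (crc << 1)
            crc &= 0xFFFF
    return crc

def key_slot(key: bytes) -> int:
    """Map a key to one of 16384 cluster slots, honouring {hash tags}."""
    start = key.find(b"{")
    if start != -1:
        end = key.find(b"}", start + 1)
        if end > start + 1:  # only a non-empty tag replaces the key
            key = key[start + 1 : end]
    return crc16_xmodem(key) % 16384

# Keys sharing a hash tag land on the same slot (and therefore shard):
print(key_slot(b"user:1000"), key_slot(b"{user:1000}.followers"))
```

Each of the three CI shards owns roughly a third of the 16384 slots, which is why the tests need every shard reachable via its published localhost port.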
@@ -286,8 +350,13 @@ jobs:
           SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
           JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
           REDIS_HOST: "localhost"
-          REDIS_PORT: "6379"
+          REDIS_PORT: "17000"
           ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
+          # Opt-in: lets backend/data/e2e_redis_restart_test.py spin up its
+          # own isolated 3-shard cluster (ports 27110–27112) and exercise
+          # ``docker restart <shard>`` mid-stream. Off locally so a
+          # contributor's ``poetry run test`` doesn't pay the ~15s cost.
+          E2E_RESTART_ISOLATED: "1"

       - name: Upload coverage reports to Codecov
         if: ${{ !cancelled() }}
5 .gitignore vendored
@@ -195,3 +195,8 @@ test.db
 # Implementation plans (generated by AI agents)
 plans/
 .claude/worktrees/
+test-results/
+
+# Playwright MCP / local browser-testing artifacts
+.playwright-mcp/
+copilot-session-switch-qa/
@@ -267,7 +267,7 @@
       "filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
       "hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
       "is_verified": false,
-      "line_number": 55
+      "line_number": 67
     }
   ],
   "autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
@@ -467,5 +467,5 @@
     }
   ]
   },
-  "generated_at": "2026-04-09T14:20:23Z"
+  "generated_at": "2026-04-24T16:42:44Z"
 }
3 autogpt_platform/.gitignore vendored
@@ -1,3 +1,6 @@
 *.ignore.*
 *.ign.*
 .application.logs
+
+# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
+.claude/settings.local.json
@@ -1,33 +0,0 @@
-from typing import Optional
-
-from pydantic import Field
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-
-class RateLimitSettings(BaseSettings):
-    redis_host: str = Field(
-        default="redis://localhost:6379",
-        description="Redis host",
-        validation_alias="REDIS_HOST",
-    )
-
-    redis_port: str = Field(
-        default="6379", description="Redis port", validation_alias="REDIS_PORT"
-    )
-
-    redis_password: Optional[str] = Field(
-        default=None,
-        description="Redis password",
-        validation_alias="REDIS_PASSWORD",
-    )
-
-    requests_per_minute: int = Field(
-        default=60,
-        description="Maximum number of requests allowed per minute per API key",
-        validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
-    )
-
-    model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
-
-
-RATE_LIMIT_SETTINGS = RateLimitSettings()
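As an aside, the env-var-with-default pattern the deleted `RateLimitSettings` got from pydantic-settings can be sketched with the stdlib alone (class and names below are invented for illustration, not a drop-in replacement):

```python
import os
from dataclasses import dataclass, field

@dataclass
class RateLimitConfig:
    # Each field falls back to a default when its env var is unset,
    # roughly mirroring pydantic-settings' validation_alias lookup.
    redis_host: str = field(
        default_factory=lambda: os.environ.get("REDIS_HOST", "localhost")
    )
    redis_port: int = field(
        default_factory=lambda: int(os.environ.get("REDIS_PORT", "6379"))
    )
    requests_per_minute: int = field(
        default_factory=lambda: int(
            os.environ.get("RATE_LIMIT_REQUESTS_PER_MINUTE", "60")
        )
    )

os.environ["RATE_LIMIT_REQUESTS_PER_MINUTE"] = "120"
cfg = RateLimitConfig()
print(cfg.requests_per_minute)  # 120, taken from the env override
```

What pydantic-settings adds over this sketch is validation and coercion errors at load time rather than ad-hoc `int()` calls.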
@@ -1,51 +0,0 @@
-import time
-from typing import Tuple
-
-from redis import Redis
-
-from .config import RATE_LIMIT_SETTINGS
-
-
-class RateLimiter:
-    def __init__(
-        self,
-        redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
-        redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
-        redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
-        requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
-    ):
-        self.redis = Redis(
-            host=redis_host,
-            port=int(redis_port),
-            password=redis_password,
-            decode_responses=True,
-        )
-        self.window = 60
-        self.max_requests = requests_per_minute
-
-    async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
-        """
-        Check if request is within rate limits.
-
-        Args:
-            api_key_id: The API key identifier to check
-
-        Returns:
-            Tuple of (is_allowed, remaining_requests, reset_time)
-        """
-        now = time.time()
-        window_start = now - self.window
-        key = f"ratelimit:{api_key_id}:1min"
-
-        pipe = self.redis.pipeline()
-        pipe.zremrangebyscore(key, 0, window_start)
-        pipe.zadd(key, {str(now): now})
-        pipe.zcount(key, window_start, now)
-        pipe.expire(key, self.window)
-
-        _, _, request_count, _ = pipe.execute()
-
-        remaining = max(0, self.max_requests - request_count)
-        reset_time = int(now + self.window)
-
-        return request_count <= self.max_requests, remaining, reset_time
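The core of the deleted limiter is a sliding one-minute window over a Redis sorted set: drop entries older than the window, add the current timestamp, count what remains. The same algorithm in-memory, for illustration (not the removed class; names here are invented):

```python
import time
from collections import deque
from typing import Optional

class SlidingWindowLimiter:
    """In-memory analogue of the ZSET algorithm: one timestamp per request,
    pruned to the last `window` seconds before counting."""

    def __init__(self, max_requests: int = 60, window: float = 60.0):
        self.max_requests = max_requests
        self.window = window
        self.hits: dict[str, deque] = {}

    def check(self, key: str, now: Optional[float] = None) -> tuple[bool, int]:
        now = time.time() if now is None else now
        q = self.hits.setdefault(key, deque())
        while q and q[0] <= now - self.window:  # ~ZREMRANGEBYSCORE
            q.popleft()
        q.append(now)                           # ~ZADD
        count = len(q)                          # ~ZCOUNT
        return count <= self.max_requests, max(0, self.max_requests - count)

limiter = SlidingWindowLimiter(max_requests=3, window=60.0)
results = [limiter.check("key-1", now=100.0 + i)[0] for i in range(4)]
print(results)  # [True, True, True, False]
```

The Redis version exists so the count survives process restarts and is shared across workers; the window arithmetic is identical.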
@@ -1,32 +0,0 @@
-from fastapi import HTTPException, Request
-from starlette.middleware.base import RequestResponseEndpoint
-
-from .limiter import RateLimiter
-
-
-async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
-    """FastAPI middleware for rate limiting API requests."""
-    limiter = RateLimiter()
-
-    if not request.url.path.startswith("/api"):
-        return await call_next(request)
-
-    api_key = request.headers.get("Authorization")
-    if not api_key:
-        return await call_next(request)
-
-    api_key = api_key.replace("Bearer ", "")
-
-    is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)
-
-    if not is_allowed:
-        raise HTTPException(
-            status_code=429, detail="Rate limit exceeded. Please try again later."
-        )
-
-    response = await call_next(request)
-    response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
-    response.headers["X-RateLimit-Remaining"] = str(remaining)
-    response.headers["X-RateLimit-Reset"] = str(reset_time)
-
-    return response
@@ -59,6 +59,8 @@ class OAuthState(BaseModel):
     code_verifier: Optional[str] = None
     scopes: list[str]
     """Unix timestamp (seconds) indicating when this OAuth state expires"""
+    credential_id: Optional[str] = None
+    """If set, this OAuth flow upgrades an existing credential's scopes."""


 class UserMetadata(BaseModel):
@@ -1,13 +1,16 @@
 import asyncio
 from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union

 from expiringdict import ExpiringDict

 if TYPE_CHECKING:
     from redis.asyncio import Redis as AsyncRedis
+    from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
     from redis.asyncio.lock import Lock as AsyncRedisLock

+AsyncRedisLike = Union[AsyncRedis, AsyncRedisCluster]
+

 class AsyncRedisKeyedMutex:
     """
@@ -17,7 +20,7 @@ class AsyncRedisKeyedMutex:
     in case the key is not unlocked for a specified duration, to prevent memory leaks.
     """

-    def __init__(self, redis: "AsyncRedis", timeout: int | None = 60):
+    def __init__(self, redis: "AsyncRedisLike", timeout: int | None = 60):
         self.redis = redis
         self.timeout = timeout
         self.locks: dict[Any, "AsyncRedisLock"] = ExpiringDict(
@@ -37,6 +37,23 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
 ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
 UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
+
+# Web Push (VAPID) — generate with: poetry run python -c "
+# from py_vapid import Vapid; import base64
+# from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
+# v = Vapid(); v.generate_keys()
+# raw_priv = v.private_key.private_numbers().private_value.to_bytes(32, 'big')
+# print('VAPID_PRIVATE_KEY=' + base64.urlsafe_b64encode(raw_priv).rstrip(b'=').decode())
+# raw_pub = v.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
+# print('VAPID_PUBLIC_KEY=' + base64.urlsafe_b64encode(raw_pub).rstrip(b'=').decode())
+# "
+# Dev-only keypair below — DO NOT use in staging/production. Regenerate
+# your own with the snippet above before any non-local deployment.
+VAPID_PRIVATE_KEY=17hBPdSdn6TR_yAgQxA0TjTcvRj3Lf6znHnASZ4rOKc
+VAPID_PUBLIC_KEY=BBg49iVTWthVbRYphwmZNvZyiSJDqtSO4nmLxDzLKe3Oo9jbtu0Usa14xX4HQQNLUeiEfzD42zWSlrvY1PR12bs
+# Per RFC 8292 push services use this in 410 Gone reports; set to a real
+# mailbox in production. Defaults to a placeholder for local dev.
+VAPID_CLAIM_EMAIL=mailto:dev@example.com
+
 ## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
 # Platform URLs (set these for webhooks and OAuth to work)
 PLATFORM_BASE_URL=http://localhost:8000
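An aside on the generation snippet above: VAPID env values are raw key bytes in URL-safe base64 with the `=` padding stripped. That unpadded encode/decode round trip can be sketched with the stdlib (the 32-byte input below is dummy data, not a real key):

```python
import base64

def b64url_unpadded(raw: bytes) -> str:
    """URL-safe base64 with trailing '=' padding removed, as the
    VAPID_* values in this .env file expect."""
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

def b64url_decode(value: str) -> bytes:
    """Re-add padding to a multiple of 4 before decoding."""
    return base64.urlsafe_b64decode(value + "=" * (-len(value) % 4))

raw_private = bytes(range(32))  # dummy 32-byte scalar, NOT a usable key
encoded = b64url_unpadded(raw_private)
assert b64url_decode(encoded) == raw_private
print(len(encoded))  # 43 chars: 32 bytes -> 44 base64 chars minus one '='
```

This is why the dev `VAPID_PRIVATE_KEY` above is 43 characters long, while the public key (a 65-byte uncompressed EC point) is 87.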
@@ -179,6 +196,13 @@ MEM0_API_KEY=
 OPENWEATHERMAP_API_KEY=
 GOOGLE_MAPS_API_KEY=
+
+# Platform Bot Linking
+PLATFORM_LINK_BASE_URL=http://localhost:3000/link
+
+# CoPilot chat-platform bridge (Discord/Telegram/Slack)
+# Uses FRONTEND_BASE_URL (above) for link confirmation pages.
+AUTOPILOT_BOT_DISCORD_TOKEN=
+
 # Communication Services
 DISCORD_BOT_TOKEN=
 MEDIUM_API_KEY=
@@ -1,14 +1,44 @@
 import asyncio
-from typing import Dict, Set
+import json
+import logging
+import time
+from typing import Awaitable, Callable, Dict, Optional, Set

-from fastapi import WebSocket
+from fastapi import WebSocket, WebSocketDisconnect
+from redis.asyncio import Redis as AsyncRedis
+from redis.asyncio.client import PubSub as AsyncPubSub
+from redis.exceptions import MovedError, RedisError, ResponseError
+from starlette.websockets import WebSocketState

-from backend.api.model import NotificationPayload, WSMessage, WSMethod
+from backend.api.model import WSMessage, WSMethod
+from backend.data import redis_client as redis
+from backend.data.event_bus import _assert_no_wildcard
 from backend.data.execution import (
     ExecutionEventType,
-    GraphExecutionEvent,
-    NodeExecutionEvent,
+    exec_channel,
+    get_graph_execution_meta,
+    graph_all_channel,
 )
+from backend.data.notification_bus import NotificationEvent
+from backend.util.settings import Settings
+
+logger = logging.getLogger(__name__)
+_settings = Settings()
+
+
+def _is_ws_close_race(exc: BaseException, websocket: WebSocket) -> bool:
+    """A SPUBLISH→WS send racing with WS close — benign, drop quietly."""
+    if isinstance(exc, WebSocketDisconnect):
+        return True
+    if (
+        getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED
+        or getattr(websocket, "client_state", None) == WebSocketState.DISCONNECTED
+    ):
+        return True
+    if isinstance(exc, RuntimeError) and "close message has been sent" in str(exc):
+        return True
+    return False


 _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
     ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
@@ -16,128 +46,379 @@ _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
 }
 
 
+def event_bus_channel(channel_key: str) -> str:
+    """Prefix a channel key with the execution event bus name."""
+    return f"{_settings.config.execution_event_bus_name}/{channel_key}"
+
+
+def _notification_bus_channel(user_id: str) -> str:
+    """Return the full sharded channel name for a user's notifications."""
+    return f"{_settings.config.notification_event_bus_name}/{user_id}"
+
+
+MessageHandler = Callable[[Optional[bytes | str]], Awaitable[None]]
+
+
+def _is_moved_error(exc: BaseException) -> bool:
+    """A MOVED redirect — slot migration mid-stream; pump should reconnect."""
+    if isinstance(exc, MovedError):
+        return True
+    if isinstance(exc, ResponseError) and str(exc).startswith("MOVED "):
+        return True
+    return False
+
+
+# Reconnect tunables for shard-failover during pubsub.listen().
+_PUMP_RECONNECT_DEADLINE_S = 60.0
+_PUMP_RECONNECT_BACKOFF_INITIAL_S = 0.5
+_PUMP_RECONNECT_BACKOFF_MAX_S = 8.0
+
+
+class _Subscription:
+    """One SSUBSCRIBE lifecycle bound to a WebSocket, pinned to the owning shard."""
+
+    def __init__(self, full_channel: str) -> None:
+        _assert_no_wildcard(full_channel)
+        self.full_channel = full_channel
+        self._client: AsyncRedis | None = None
+        self._pubsub: AsyncPubSub | None = None
+        self._task: asyncio.Task | None = None
+
+    async def start(self, on_message: MessageHandler) -> None:
+        await self._open_pubsub()
+        self._task = asyncio.create_task(self._pump(on_message))
+
+    async def _open_pubsub(self) -> None:
+        """(Re)establish the sharded pubsub connection + SSUBSCRIBE."""
+        self._client = await redis.connect_sharded_pubsub_async(self.full_channel)
+        self._pubsub = self._client.pubsub()
+        await self._pubsub.execute_command("SSUBSCRIBE", self.full_channel)
+        # redis-py 6.x async PubSub.listen() exits when ``channels`` is
+        # empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
+        self._pubsub.channels[self.full_channel] = None  # type: ignore[index]
+
+    async def _close_pubsub_quietly(self) -> None:
+        """Best-effort teardown before reconnect — never raises."""
+        if self._pubsub is not None:
+            try:
+                await self._pubsub.aclose()
+            except Exception:
+                pass
+            self._pubsub = None
+        if self._client is not None:
+            try:
+                await self._client.aclose()
+            except Exception:
+                pass
+            self._client = None
+
+    async def _pump(self, on_message: MessageHandler) -> None:
+        if self._pubsub is None:
+            return
+        backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
+        deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
+        while True:
+            pubsub = self._pubsub
+            if pubsub is None:
+                return
+            needs_reconnect = False
+            try:
+                async for message in pubsub.listen():
+                    msg_type = message.get("type")
+                    # Server-pushed sunsubscribe: slot ownership changed and
+                    # Redis revoked our SSUBSCRIBE without dropping the TCP.
+                    # Treat as a reconnect trigger so we re-resolve the shard.
+                    if msg_type == "sunsubscribe":
+                        needs_reconnect = True
+                        break
+                    if msg_type not in ("smessage", "message", "pmessage"):
+                        continue
+                    # Successful read resets the reconnect budget.
+                    backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
+                    deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
+                    try:
+                        await on_message(message.get("data"))
+                    except Exception:
+                        logger.exception(
+                            "Websocket message-handler failed for channel %s",
+                            self.full_channel,
+                        )
+                if not needs_reconnect:
+                    # listen() exited cleanly (channels emptied) — pump is done.
+                    return
+            except asyncio.CancelledError:
+                raise
+            except (ConnectionError, RedisError) as exc:
+                if isinstance(exc, ResponseError) and not _is_moved_error(exc):
+                    logger.exception(
+                        "Pubsub pump crashed on non-retryable ResponseError for %s",
+                        self.full_channel,
+                    )
+                    return
+                if time.monotonic() > deadline:
+                    logger.exception(
+                        "Pubsub pump giving up after reconnect deadline for %s",
+                        self.full_channel,
+                    )
+                    return
+                logger.warning(
+                    "Pubsub pump reconnecting for %s after %s: %s",
+                    self.full_channel,
+                    type(exc).__name__,
+                    exc,
+                )
+            except Exception:
+                logger.exception("Pubsub pump crashed for %s", self.full_channel)
+                return
+
+            # Either a retryable error was raised, or the server pushed a
+            # sunsubscribe — close the stale pubsub and reopen against the
+            # (possibly migrated) shard.
+            await self._close_pubsub_quietly()
+            await asyncio.sleep(backoff)
+            backoff = min(backoff * 2, _PUMP_RECONNECT_BACKOFF_MAX_S)
+            try:
+                await self._open_pubsub()
+            except (ConnectionError, RedisError) as reopen_exc:
+                logger.warning(
+                    "Pubsub pump reopen failed for %s: %s",
+                    self.full_channel,
+                    reopen_exc,
+                )
+                # Loop again — deadline check will eventually exit.
+                continue
+
+    async def stop(self) -> None:
+        if self._task is not None:
+            self._task.cancel()
+            try:
+                await self._task
+            except (asyncio.CancelledError, Exception):
+                pass
+            self._task = None
+        if self._pubsub is not None:
+            try:
+                await self._pubsub.execute_command("SUNSUBSCRIBE", self.full_channel)
+            except Exception:
+                logger.warning(
+                    "SUNSUBSCRIBE failed for %s", self.full_channel, exc_info=True
+                )
+            try:
+                await self._pubsub.aclose()
+            except Exception:
+                pass
+            self._pubsub = None
+        if self._client is not None:
+            try:
+                await self._client.aclose()
+            except Exception:
+                pass
+            self._client = None
+
+
 class ConnectionManager:
     def __init__(self):
         self.active_connections: Set[WebSocket] = set()
+        # channel_key → sockets subscribed (public channel keys, not raw Redis channels)
         self.subscriptions: Dict[str, Set[WebSocket]] = {}
-        self.user_connections: Dict[str, Set[WebSocket]] = {}
+        # websocket → {channel_key: _Subscription}
+        self._ws_subs: Dict[WebSocket, Dict[str, _Subscription]] = {}
+        # websocket → notification subscription
+        self._ws_notifications: Dict[WebSocket, _Subscription] = {}
 
     async def connect_socket(self, websocket: WebSocket, *, user_id: str):
         await websocket.accept()
         self.active_connections.add(websocket)
-        if user_id not in self.user_connections:
-            self.user_connections[user_id] = set()
-        self.user_connections[user_id].add(websocket)
+        self._ws_subs.setdefault(websocket, {})
+        await self._start_notification_subscription(websocket, user_id=user_id)
 
-    def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
+    async def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
         self.active_connections.discard(websocket)
-        for subscribers in self.subscriptions.values():
+        # Stop SSUBSCRIBE pumps before dropping bookkeeping to avoid leaks.
+        subs = self._ws_subs.pop(websocket, {})
+        for sub in subs.values():
+            await sub.stop()
+        notif_sub = self._ws_notifications.pop(websocket, None)
+        if notif_sub is not None:
+            await notif_sub.stop()
+        for channel_key, subscribers in list(self.subscriptions.items()):
             subscribers.discard(websocket)
-        user_conns = self.user_connections.get(user_id)
-        if user_conns is not None:
-            user_conns.discard(websocket)
-            if not user_conns:
-                self.user_connections.pop(user_id, None)
+            if not subscribers:
+                self.subscriptions.pop(channel_key, None)
 
     async def subscribe_graph_exec(
         self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
     ) -> str:
-        return await self._subscribe(
-            _graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
-        )
+        # Hash-tagged channel needs graph_id; resolve once per subscribe.
+        meta = await get_graph_execution_meta(user_id, graph_exec_id)
+        if meta is None:
+            raise ValueError(
+                f"graph_exec #{graph_exec_id} not found for user #{user_id}"
+            )
+        channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
+        full_channel = event_bus_channel(
+            exec_channel(user_id, meta.graph_id, graph_exec_id)
+        )
+        await self._open_subscription(websocket, channel_key, full_channel)
+        return channel_key
 
     async def subscribe_graph_execs(
         self, *, user_id: str, graph_id: str, websocket: WebSocket
     ) -> str:
-        return await self._subscribe(
-            _graph_execs_channel_key(user_id, graph_id=graph_id), websocket
-        )
+        channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
+        full_channel = event_bus_channel(graph_all_channel(user_id, graph_id))
+        await self._open_subscription(websocket, channel_key, full_channel)
+        return channel_key
 
     async def unsubscribe_graph_exec(
         self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
     ) -> str | None:
-        return await self._unsubscribe(
-            _graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
-        )
+        channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
+        return await self._close_subscription(websocket, channel_key)
 
     async def unsubscribe_graph_execs(
         self, *, user_id: str, graph_id: str, websocket: WebSocket
     ) -> str | None:
-        return await self._unsubscribe(
-            _graph_execs_channel_key(user_id, graph_id=graph_id), websocket
-        )
+        channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
+        return await self._close_subscription(websocket, channel_key)
 
-    async def send_execution_update(
-        self, exec_event: GraphExecutionEvent | NodeExecutionEvent
-    ) -> int:
-        graph_exec_id = (
-            exec_event.id
-            if isinstance(exec_event, GraphExecutionEvent)
-            else exec_event.graph_exec_id
-        )
-
-        n_sent = 0
-        channels: set[str] = {
-            # Send update to listeners for this graph execution
-            _graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
-        }
-        if isinstance(exec_event, GraphExecutionEvent):
-            # Send update to listeners for all executions of this graph
-            channels.add(
-                _graph_execs_channel_key(
-                    exec_event.user_id, graph_id=exec_event.graph_id
-                )
-            )
-
-        for channel in channels.intersection(self.subscriptions.keys()):
-            message = WSMessage(
-                method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
-                channel=channel,
-                data=exec_event.model_dump(),
-            ).model_dump_json()
-            for connection in self.subscriptions[channel]:
-                await connection.send_text(message)
-                n_sent += 1
-
-        return n_sent
-
-    async def send_notification(
-        self, *, user_id: str, payload: NotificationPayload
-    ) -> int:
-        """Send a notification to all websocket connections belonging to a user."""
-        message = WSMessage(
-            method=WSMethod.NOTIFICATION,
-            data=payload.model_dump(),
-        ).model_dump_json()
-
-        connections = tuple(self.user_connections.get(user_id, set()))
-        if not connections:
-            return 0
-
-        await asyncio.gather(
-            *(connection.send_text(message) for connection in connections),
-            return_exceptions=True,
-        )
-
-        return len(connections)
-
-    async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
-        if channel_key not in self.subscriptions:
-            self.subscriptions[channel_key] = set()
-        self.subscriptions[channel_key].add(websocket)
-        return channel_key
-
-    async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
-        if channel_key in self.subscriptions:
-            self.subscriptions[channel_key].discard(websocket)
-            if not self.subscriptions[channel_key]:
-                del self.subscriptions[channel_key]
-            return channel_key
-        return None
+    async def _open_subscription(
+        self, websocket: WebSocket, channel_key: str, full_channel: str
+    ) -> None:
+        self.subscriptions.setdefault(channel_key, set()).add(websocket)
+        per_ws = self._ws_subs.setdefault(websocket, {})
+        if channel_key in per_ws:
+            return
+        sub = _Subscription(full_channel)
+
+        async def on_message(data: Optional[bytes | str]) -> None:
+            await self._forward_exec_event(websocket, channel_key, data)
+
+        await sub.start(on_message)
+        per_ws[channel_key] = sub
+
+    async def _close_subscription(
+        self, websocket: WebSocket, channel_key: str
+    ) -> str | None:
+        subscribers = self.subscriptions.get(channel_key)
+        if subscribers is None:
+            return None
+        subscribers.discard(websocket)
+        if not subscribers:
+            self.subscriptions.pop(channel_key, None)
+        per_ws = self._ws_subs.get(websocket)
+        if per_ws and channel_key in per_ws:
+            sub = per_ws.pop(channel_key)
+            await sub.stop()
+        return channel_key
+
+    async def _forward_exec_event(
+        self,
+        websocket: WebSocket,
+        channel_key: str,
+        raw_payload: Optional[bytes | str],
+    ) -> None:
+        if raw_payload is None:
+            return
+        # Unwrap the `_EventPayloadWrapper` envelope, then re-wrap as a WS message.
+        try:
+            wrapper = (
+                raw_payload.decode()
+                if isinstance(raw_payload, (bytes, bytearray))
+                else raw_payload
+            )
+        except Exception:
+            logger.warning(
+                "Failed to decode pubsub payload on %s", channel_key, exc_info=True
+            )
+            return
+
+        try:
+            parsed = json.loads(wrapper)
+            event_data = parsed.get("payload")
+            if not isinstance(event_data, dict):
+                return
+            event_type = event_data.get("event_type")
+            method = _EVENT_TYPE_TO_METHOD_MAP.get(ExecutionEventType(event_type))
+            if method is None:
+                return
+            message = WSMessage(
+                method=method,
+                channel=channel_key,
+                data=event_data,
+            ).model_dump_json()
+            await websocket.send_text(message)
+        except Exception as e:
+            if _is_ws_close_race(e, websocket):
+                logger.debug("Dropped exec event on closed WS for %s", channel_key)
+                return
+            logger.exception("Failed to forward exec event on %s", channel_key)
+
+    async def _start_notification_subscription(
+        self, websocket: WebSocket, *, user_id: str
+    ) -> None:
+        full_channel = _notification_bus_channel(user_id)
+        sub = _Subscription(full_channel)
+
+        async def on_message(data: Optional[bytes | str]) -> None:
+            await self._forward_notification(websocket, user_id, data)
+
+        try:
+            await sub.start(on_message)
+        except Exception:
+            logger.exception(
+                "Failed to open notification SSUBSCRIBE for user=%s", user_id
+            )
+            return
+        self._ws_notifications[websocket] = sub
+
+    async def _forward_notification(
+        self,
+        websocket: WebSocket,
+        user_id: str,
+        raw_payload: Optional[bytes | str],
+    ) -> None:
+        if raw_payload is None:
+            return
+        try:
+            wrapper_json = (
+                raw_payload.decode()
+                if isinstance(raw_payload, (bytes, bytearray))
+                else raw_payload
+            )
+            parsed = json.loads(wrapper_json)
+            inner = parsed.get("payload") if isinstance(parsed, dict) else None
+            if not isinstance(inner, dict):
+                return
+            event = NotificationEvent.model_validate(inner)
+        except Exception:
+            logger.warning(
+                "Failed to parse notification payload for user=%s",
+                user_id,
+                exc_info=True,
+            )
+            return
+        # Defense in depth against cross-user payloads.
+        if event.user_id != user_id:
+            return
+        message = WSMessage(
+            method=WSMethod.NOTIFICATION,
+            data=event.payload.model_dump(),
+        ).model_dump_json()
+        try:
+            await websocket.send_text(message)
+        except Exception as e:
+            if _is_ws_close_race(e, websocket):
+                logger.debug("Dropped notification on closed WS for user=%s", user_id)
+                return
+            logger.warning(
+                "Failed to deliver notification to WS for user=%s",
+                user_id,
+                exc_info=True,
+            )
 
 
-def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
+def graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
     return f"{user_id}|graph_exec#{graph_exec_id}"
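The reconnect policy in `_pump` above, exponential backoff starting at 0.5 s and capped at 8 s, under a rolling 60 s deadline that resets on every successful read, can be sketched in isolation. The constants are copied from the diff; the helper itself is an illustrative stand-in, not the real pump:

```python
_PUMP_RECONNECT_DEADLINE_S = 60.0
_PUMP_RECONNECT_BACKOFF_INITIAL_S = 0.5
_PUMP_RECONNECT_BACKOFF_MAX_S = 8.0


def backoff_schedule(n: int) -> list[float]:
    """First n sleep intervals the pump would use between reconnect attempts."""
    delays: list[float] = []
    backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
    for _ in range(n):
        delays.append(backoff)
        # Same update as the pump: double, clamp at the max.
        backoff = min(backoff * 2, _PUMP_RECONNECT_BACKOFF_MAX_S)
    return delays


print(backoff_schedule(7))  # [0.5, 1.0, 2.0, 4.0, 8.0, 8.0, 8.0]
```

Because the deadline only advances on a successful read, a shard that stays down stops being retried after roughly `_PUMP_RECONNECT_DEADLINE_S` seconds regardless of how many attempts fit in that window.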
@@ -0,0 +1,386 @@
+"""ConnectionManager integration over the live 3-shard Redis cluster:
+SSUBSCRIBE → SPUBLISH → WebSocket forwarding with no Redis mocks. Skips
+when the cluster is unreachable."""
+
+import asyncio
+import json
+from datetime import datetime, timezone
+from unittest.mock import AsyncMock
+from uuid import uuid4
+
+import pytest
+from fastapi import WebSocket
+
+import backend.data.redis_client as redis_client
+from backend.api.conn_manager import (
+    ConnectionManager,
+    _graph_execs_channel_key,
+    event_bus_channel,
+    graph_exec_channel_key,
+)
+from backend.api.model import WSMethod
+from backend.data.execution import (
+    ExecutionStatus,
+    GraphExecutionEvent,
+    GraphExecutionMeta,
+    NodeExecutionEvent,
+    exec_channel,
+    graph_all_channel,
+)
+
+
+def _has_live_cluster() -> bool:
+    try:
+        c = redis_client.connect()
+    except Exception:  # noqa: BLE001 — any connect failure → skip
+        return False
+    try:
+        c.close()
+    except Exception:
+        pass
+    return True
+
+
+pytestmark = pytest.mark.skipif(
+    not _has_live_cluster(),
+    reason="local redis cluster not reachable; skip conn_manager integration",
+)
+
+
+def _meta(user_id: str, graph_id: str, graph_exec_id: str) -> GraphExecutionMeta:
+    """Build a minimal GraphExecutionMeta for ``subscribe_graph_exec`` to use."""
+    return GraphExecutionMeta(
+        id=graph_exec_id,
+        user_id=user_id,
+        graph_id=graph_id,
+        graph_version=1,
+        inputs=None,
+        credential_inputs=None,
+        nodes_input_masks=None,
+        preset_id=None,
+        status=ExecutionStatus.RUNNING,
+        started_at=datetime.now(tz=timezone.utc),
+        ended_at=None,
+        stats=GraphExecutionMeta.Stats(),
+    )
+
+
+def _node_event_payload(
+    *, user_id: str, graph_id: str, graph_exec_id: str, marker: str
+) -> bytes:
+    """Wire-format a NodeExecutionEvent the way RedisExecutionEventBus would."""
+    inner = NodeExecutionEvent(
+        user_id=user_id,
+        graph_id=graph_id,
+        graph_version=1,
+        graph_exec_id=graph_exec_id,
+        node_exec_id=f"node-exec-{marker}",
+        node_id="node-1",
+        block_id="block-1",
+        status=ExecutionStatus.COMPLETED,
+        input_data={"in": marker},
+        output_data={"out": [marker]},
+        add_time=datetime.now(tz=timezone.utc),
+        queue_time=None,
+        start_time=datetime.now(tz=timezone.utc),
+        end_time=datetime.now(tz=timezone.utc),
+    ).model_dump(mode="json")
+    return json.dumps({"payload": inner}).encode()
+
+
+def _graph_event_payload(
+    *, user_id: str, graph_id: str, graph_exec_id: str, marker: str
+) -> bytes:
+    inner = GraphExecutionEvent(
+        id=graph_exec_id,
+        user_id=user_id,
+        graph_id=graph_id,
+        graph_version=1,
+        preset_id=None,
+        status=ExecutionStatus.COMPLETED,
+        started_at=datetime.now(tz=timezone.utc),
+        ended_at=datetime.now(tz=timezone.utc),
+        stats=GraphExecutionEvent.Stats(
+            cost=0,
+            duration=1.0,
+            node_exec_time=0.5,
+            node_exec_count=1,
+        ),
+        inputs={"x": marker},
+        credential_inputs=None,
+        nodes_input_masks=None,
+        outputs={"y": [marker]},
+    ).model_dump(mode="json")
+    return json.dumps({"payload": inner}).encode()
+
+
+async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
+    """Poll ``predicate()`` until truthy or timeout — used to wait for pubsub."""
+    deadline = asyncio.get_event_loop().time() + timeout
+    while asyncio.get_event_loop().time() < deadline:
+        if predicate():
+            return True
+        await asyncio.sleep(interval)
+    return False
+
+
+@pytest.mark.asyncio
+async def test_two_clients_get_independent_ssubscribes_on_right_shards(
+    monkeypatch,
+) -> None:
+    """Two WS clients on different graph_exec_ids each receive ONLY their
+    own publish, even when the channels land on different shards."""
+    user_id = "user-conn-int-1"
+    graph_a = f"graph-a-{uuid4().hex[:8]}"
+    graph_b = f"graph-b-{uuid4().hex[:8]}"
+    exec_a = f"exec-a-{uuid4().hex[:8]}"
+    exec_b = f"exec-b-{uuid4().hex[:8]}"
+
+    # Stub Prisma lookup so tests don't need a DB.
+    async def _fake_meta(_uid, gex_id):
+        return _meta(user_id, graph_a if gex_id == exec_a else graph_b, gex_id)
+
+    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
+
+    cm = ConnectionManager()
+    ws_a: AsyncMock = AsyncMock(spec=WebSocket)
+    ws_b: AsyncMock = AsyncMock(spec=WebSocket)
+    sent_a: list[str] = []
+    sent_b: list[str] = []
+    ws_a.send_text = AsyncMock(side_effect=lambda m: sent_a.append(m))
+    ws_b.send_text = AsyncMock(side_effect=lambda m: sent_b.append(m))
+
+    redis_client.get_redis.cache_clear()
+    cluster = redis_client.get_redis()
+
+    try:
+        await cm.subscribe_graph_exec(
+            user_id=user_id, graph_exec_id=exec_a, websocket=ws_a
+        )
+        await cm.subscribe_graph_exec(
+            user_id=user_id, graph_exec_id=exec_b, websocket=ws_b
+        )
+        # Let SSUBSCRIBE settle on each shard.
+        await asyncio.sleep(0.2)
+
+        # Publish to each per-exec channel.
+        chan_a = event_bus_channel(exec_channel(user_id, graph_a, exec_a))
+        chan_b = event_bus_channel(exec_channel(user_id, graph_b, exec_b))
+        cluster.spublish(
+            chan_a,
+            _node_event_payload(
+                user_id=user_id,
+                graph_id=graph_a,
+                graph_exec_id=exec_a,
+                marker="A",
+            ).decode(),
+        )
+        cluster.spublish(
+            chan_b,
+            _node_event_payload(
+                user_id=user_id,
+                graph_id=graph_b,
+                graph_exec_id=exec_b,
+                marker="B",
+            ).decode(),
+        )
+
+        delivered = await _wait_until(lambda: sent_a and sent_b, timeout=5.0)
+        assert delivered, f"timeout: sent_a={sent_a!r} sent_b={sent_b!r}"
+
+        msg_a = json.loads(sent_a[0])
+        msg_b = json.loads(sent_b[0])
+        assert msg_a["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_a)
+        assert msg_b["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_b)
+        assert msg_a["data"]["graph_exec_id"] == exec_a
+        assert msg_b["data"]["graph_exec_id"] == exec_b
+        # No cross-talk: each socket got exactly one message.
+        assert len(sent_a) == 1 and len(sent_b) == 1
+    finally:
+        await cm.disconnect_socket(ws_a, user_id=user_id)
+        await cm.disconnect_socket(ws_b, user_id=user_id)
+        redis_client.disconnect()
+
+
+@pytest.mark.asyncio
+async def test_aggregate_channel_receives_per_exec_publishes(monkeypatch) -> None:
+    """A subscriber on the ``graph_execs`` aggregate channel must receive the
+    GraphExecutionEvent published to the ``/all`` channel — even though
+    per-exec events go to a different channel."""
+    user_id = "user-conn-int-2"
+    graph_id = f"graph-{uuid4().hex[:8]}"
+    exec_id = f"exec-{uuid4().hex[:8]}"
+
+    async def _fake_meta(_uid, gex_id):
+        return _meta(user_id, graph_id, gex_id)
+
+    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
+
+    cm = ConnectionManager()
+    ws_agg: AsyncMock = AsyncMock(spec=WebSocket)
+    ws_per: AsyncMock = AsyncMock(spec=WebSocket)
+    sent_agg: list[str] = []
+    sent_per: list[str] = []
+    ws_agg.send_text = AsyncMock(side_effect=lambda m: sent_agg.append(m))
+    ws_per.send_text = AsyncMock(side_effect=lambda m: sent_per.append(m))
+
+    redis_client.get_redis.cache_clear()
+    cluster = redis_client.get_redis()
+
+    try:
+        await cm.subscribe_graph_execs(
+            user_id=user_id, graph_id=graph_id, websocket=ws_agg
+        )
+        await cm.subscribe_graph_exec(
+            user_id=user_id, graph_exec_id=exec_id, websocket=ws_per
+        )
+        await asyncio.sleep(0.2)
+
+        # The eventbus publishes the same event to both channels — replicate.
+        chan_per = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
+        chan_all = event_bus_channel(graph_all_channel(user_id, graph_id))
+        payload = _graph_event_payload(
+            user_id=user_id,
+            graph_id=graph_id,
+            graph_exec_id=exec_id,
+            marker="agg",
+        ).decode()
+        cluster.spublish(chan_per, payload)
+        cluster.spublish(chan_all, payload)
+
+        delivered = await _wait_until(lambda: sent_agg and sent_per, timeout=5.0)
+        assert delivered, f"sent_agg={sent_agg!r} sent_per={sent_per!r}"
+        agg_msg = json.loads(sent_agg[0])
+        per_msg = json.loads(sent_per[0])
+        # Aggregate subscriber's channel key is the per-graph executions key.
+        assert agg_msg["channel"] == _graph_execs_channel_key(
+            user_id, graph_id=graph_id
+        )
+        assert per_msg["channel"] == graph_exec_channel_key(
+            user_id, graph_exec_id=exec_id
+        )
+        assert agg_msg["method"] == WSMethod.GRAPH_EXECUTION_EVENT.value
+    finally:
+        await cm.disconnect_socket(ws_agg, user_id=user_id)
+        await cm.disconnect_socket(ws_per, user_id=user_id)
+        redis_client.disconnect()
+
+
+@pytest.mark.asyncio
+async def test_disconnect_unsubscribes_and_drops_future_publishes(monkeypatch) -> None:
+    """After ``disconnect_socket`` runs, a subsequent SPUBLISH must NOT reach
+    the dead websocket — exercises the SUNSUBSCRIBE + bookkeeping cleanup."""
+    user_id = "user-conn-int-3"
+    graph_id = f"graph-{uuid4().hex[:8]}"
+    exec_id = f"exec-{uuid4().hex[:8]}"
+
+    async def _fake_meta(_uid, gex_id):
+        return _meta(user_id, graph_id, gex_id)
+
+    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
+
+    cm = ConnectionManager()
+    ws: AsyncMock = AsyncMock(spec=WebSocket)
+    sent: list[str] = []
+    ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
+
+    redis_client.get_redis.cache_clear()
+    cluster = redis_client.get_redis()
+    chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
+    payload = _node_event_payload(
+        user_id=user_id, graph_id=graph_id, graph_exec_id=exec_id, marker="live"
+    ).decode()
+
+    try:
+        await cm.subscribe_graph_exec(
+            user_id=user_id, graph_exec_id=exec_id, websocket=ws
+        )
+        await asyncio.sleep(0.15)
+
+        # First publish — must reach the socket.
+        cluster.spublish(chan, payload)
+        delivered = await _wait_until(lambda: bool(sent), timeout=5.0)
+        assert delivered
+        assert len(sent) == 1
+
+        # Disconnect → SUNSUBSCRIBE + bookkeeping cleared.
+        await cm.disconnect_socket(ws, user_id=user_id)
+        # Pump cancellation may drain in-flight messages; wait for it.
+        await asyncio.sleep(0.2)
+
+        # Channel bookkeeping must be gone.
+        assert (
+            graph_exec_channel_key(user_id, graph_exec_id=exec_id)
+            not in cm.subscriptions
+        )
+        assert ws not in cm._ws_subs
+
+        # Second publish — must NOT reach the (already-disconnected) socket.
|
||||||
|
cluster.spublish(
|
||||||
|
chan,
|
||||||
|
_node_event_payload(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_exec_id=exec_id,
|
||||||
|
marker="post-disconnect",
|
||||||
|
).decode(),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
# Still only the one pre-disconnect message.
|
||||||
|
assert len(sent) == 1
|
||||||
|
finally:
|
||||||
|
redis_client.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_slow_consumer_receives_all_events_without_loss(monkeypatch) -> None:
|
||||||
|
"""Burst-publish many SPUBLISHes; assert every one reaches the subscriber
|
||||||
|
in order — guards against drops/reorderings in the pubsub pump."""
|
||||||
|
user_id = "user-conn-int-4"
|
||||||
|
graph_id = f"graph-{uuid4().hex[:8]}"
|
||||||
|
exec_id = f"exec-{uuid4().hex[:8]}"
|
||||||
|
n_events = 100
|
||||||
|
|
||||||
|
async def _fake_meta(_uid, gex_id):
|
||||||
|
return _meta(user_id, graph_id, gex_id)
|
||||||
|
|
||||||
|
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
|
||||||
|
|
||||||
|
cm = ConnectionManager()
|
||||||
|
ws: AsyncMock = AsyncMock(spec=WebSocket)
|
||||||
|
sent: list[str] = []
|
||||||
|
ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
|
||||||
|
|
||||||
|
redis_client.get_redis.cache_clear()
|
||||||
|
cluster = redis_client.get_redis()
|
||||||
|
chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cm.subscribe_graph_exec(
|
||||||
|
user_id=user_id, graph_exec_id=exec_id, websocket=ws
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
# Burst-publish n_events without yielding to the pump.
|
||||||
|
for i in range(n_events):
|
||||||
|
cluster.spublish(
|
||||||
|
chan,
|
||||||
|
_node_event_payload(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_exec_id=exec_id,
|
||||||
|
marker=f"m{i}",
|
||||||
|
).decode(),
|
||||||
|
)
|
||||||
|
|
||||||
|
delivered = await _wait_until(
|
||||||
|
lambda: len(sent) >= n_events, timeout=15.0, interval=0.1
|
||||||
|
)
|
||||||
|
assert delivered, f"only delivered {len(sent)}/{n_events}"
|
||||||
|
|
||||||
|
# Validate ordering — Redis pub/sub is FIFO per channel.
|
||||||
|
markers = [json.loads(m)["data"]["input_data"]["in"] for m in sent[:n_events]]
|
||||||
|
assert markers == [f"m{i}" for i in range(n_events)]
|
||||||
|
finally:
|
||||||
|
await cm.disconnect_socket(ws, user_id=user_id)
|
||||||
|
redis_client.disconnect()
|
||||||
File diff suppressed because it is too large
@@ -0,0 +1,932 @@
import asyncio
import logging
from typing import List

from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.models import User as AuthUser
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import AgentExecutionStatus
from pydantic import BaseModel

from backend.api.features.admin.model import (
    AgentDiagnosticsResponse,
    ExecutionDiagnosticsResponse,
)
from backend.data.diagnostics import (
    FailedExecutionDetail,
    OrphanedScheduleDetail,
    RunningExecutionDetail,
    ScheduleDetail,
    ScheduleHealthMetrics,
    cleanup_all_stuck_queued_executions,
    cleanup_orphaned_executions_bulk,
    cleanup_orphaned_schedules_bulk,
    get_agent_diagnostics,
    get_all_orphaned_execution_ids,
    get_all_schedules_details,
    get_all_stuck_queued_execution_ids,
    get_execution_diagnostics,
    get_failed_executions_count,
    get_failed_executions_details,
    get_invalid_executions_details,
    get_long_running_executions_details,
    get_orphaned_executions_details,
    get_orphaned_schedules_details,
    get_running_executions_details,
    get_schedule_health_metrics,
    get_stuck_queued_executions_details,
    stop_all_long_running_executions,
)
from backend.data.execution import get_graph_executions
from backend.executor.utils import add_graph_execution, stop_graph_execution

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/admin",
    tags=["diagnostics", "admin"],
    dependencies=[Security(requires_admin_user)],
)


class RunningExecutionsListResponse(BaseModel):
    """Response model for list of running executions"""

    executions: List[RunningExecutionDetail]
    total: int


class FailedExecutionsListResponse(BaseModel):
    """Response model for list of failed executions"""

    executions: List[FailedExecutionDetail]
    total: int


class StopExecutionRequest(BaseModel):
    """Request model for stopping a single execution"""

    execution_id: str


class StopExecutionsRequest(BaseModel):
    """Request model for stopping multiple executions"""

    execution_ids: List[str]


class StopExecutionResponse(BaseModel):
    """Response model for stop execution operations"""

    success: bool
    stopped_count: int = 0
    message: str


class RequeueExecutionResponse(BaseModel):
    """Response model for requeue execution operations"""

    success: bool
    requeued_count: int = 0
    message: str


@router.get(
    "/diagnostics/executions",
    response_model=ExecutionDiagnosticsResponse,
    summary="Get Execution Diagnostics",
)
async def get_execution_diagnostics_endpoint():
    """
    Get comprehensive diagnostic information about execution status.

    Returns all execution metrics including:
    - Current state (running, queued)
    - Orphaned executions (>24h old, likely not in executor)
    - Failure metrics (1h, 24h, rate)
    - Long-running detection (stuck >1h, >24h)
    - Stuck queued detection
    - Throughput metrics (completions/hour)
    - RabbitMQ queue depths
    """
    logger.info("Getting execution diagnostics")

    diagnostics = await get_execution_diagnostics()

    response = ExecutionDiagnosticsResponse(
        running_executions=diagnostics.running_count,
        queued_executions_db=diagnostics.queued_db_count,
        queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth,
        cancel_queue_depth=diagnostics.cancel_queue_depth,
        orphaned_running=diagnostics.orphaned_running,
        orphaned_queued=diagnostics.orphaned_queued,
        failed_count_1h=diagnostics.failed_count_1h,
        failed_count_24h=diagnostics.failed_count_24h,
        failure_rate_24h=diagnostics.failure_rate_24h,
        stuck_running_24h=diagnostics.stuck_running_24h,
        stuck_running_1h=diagnostics.stuck_running_1h,
        oldest_running_hours=diagnostics.oldest_running_hours,
        stuck_queued_1h=diagnostics.stuck_queued_1h,
        queued_never_started=diagnostics.queued_never_started,
        invalid_queued_with_start=diagnostics.invalid_queued_with_start,
        invalid_running_without_start=diagnostics.invalid_running_without_start,
        completed_1h=diagnostics.completed_1h,
        completed_24h=diagnostics.completed_24h,
        throughput_per_hour=diagnostics.throughput_per_hour,
        timestamp=diagnostics.timestamp,
    )

    logger.info(
        f"Execution diagnostics: running={diagnostics.running_count}, "
        f"queued_db={diagnostics.queued_db_count}, "
        f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, "
        f"failed_24h={diagnostics.failed_count_24h}"
    )

    return response

@router.get(
    "/diagnostics/agents",
    response_model=AgentDiagnosticsResponse,
    summary="Get Agent Diagnostics",
)
async def get_agent_diagnostics_endpoint():
    """
    Get diagnostic information about agents.

    Returns:
        - agents_with_active_executions: Number of unique agents with running/queued executions
        - timestamp: Current timestamp
    """
    logger.info("Getting agent diagnostics")

    diagnostics = await get_agent_diagnostics()

    response = AgentDiagnosticsResponse(
        agents_with_active_executions=diagnostics.agents_with_active_executions,
        timestamp=diagnostics.timestamp,
    )

    logger.info(
        f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}"
    )

    return response


@router.get(
    "/diagnostics/executions/running",
    response_model=RunningExecutionsListResponse,
    summary="List Running Executions",
)
async def list_running_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of running and queued executions (recent, likely active).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of running executions with details
    """
    logger.info(f"Listing running executions (limit={limit}, offset={offset})")

    executions = await get_running_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.running_count + diagnostics.queued_db_count

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/orphaned",
    response_model=RunningExecutionsListResponse,
    summary="List Orphaned Executions",
)
async def list_orphaned_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of orphaned executions (>24h old, likely not in executor).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of orphaned executions with details
    """
    logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})")

    executions = await get_orphaned_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.orphaned_running + diagnostics.orphaned_queued

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/failed",
    response_model=FailedExecutionsListResponse,
    summary="List Failed Executions",
)
async def list_failed_executions(
    limit: int = 100,
    offset: int = 0,
    hours: int = 24,
):
    """
    Get detailed list of failed executions.

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)
        hours: Number of hours to look back (default 24)

    Returns:
        List of failed executions with error details
    """
    logger.info(
        f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})"
    )

    executions = await get_failed_executions_details(
        limit=limit, offset=offset, hours=hours
    )

    # Get total count for pagination
    # Always count actual total for given hours parameter
    total = await get_failed_executions_count(hours=hours)

    return FailedExecutionsListResponse(executions=executions, total=total)

@router.get(
    "/diagnostics/executions/long-running",
    response_model=RunningExecutionsListResponse,
    summary="List Long-Running Executions",
)
async def list_long_running_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of long-running executions (RUNNING status >24h).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of long-running executions with details
    """
    logger.info(f"Listing long-running executions (limit={limit}, offset={offset})")

    executions = await get_long_running_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.stuck_running_24h

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/stuck-queued",
    response_model=RunningExecutionsListResponse,
    summary="List Stuck Queued Executions",
)
async def list_stuck_queued_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of stuck queued executions (QUEUED >1h, never started).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of stuck queued executions with details
    """
    logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})")

    executions = await get_stuck_queued_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = diagnostics.stuck_queued_1h

    return RunningExecutionsListResponse(executions=executions, total=total)


@router.get(
    "/diagnostics/executions/invalid",
    response_model=RunningExecutionsListResponse,
    summary="List Invalid Executions",
)
async def list_invalid_executions(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of executions in invalid states (READ-ONLY).

    Invalid states indicate data corruption and require manual investigation:
    - QUEUED but has startedAt (impossible - can't start while queued)
    - RUNNING but no startedAt (impossible - can't run without starting)

    ⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation.

    Each invalid execution likely has a different root cause (crashes, race conditions,
    DB corruption). Investigate the execution history and logs to determine appropriate
    action (manual cleanup, status fix, or leave as-is if system recovered).

    Args:
        limit: Maximum number of executions to return (default 100)
        offset: Number of executions to skip (default 0)

    Returns:
        List of invalid state executions with details
    """
    logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})")

    executions = await get_invalid_executions_details(limit=limit, offset=offset)

    # Get total count for pagination
    diagnostics = await get_execution_diagnostics()
    total = (
        diagnostics.invalid_queued_with_start
        + diagnostics.invalid_running_without_start
    )

    return RunningExecutionsListResponse(executions=executions, total=total)

@router.post(
    "/diagnostics/executions/requeue",
    response_model=RequeueExecutionResponse,
    summary="Requeue Stuck Execution",
)
async def requeue_single_execution(
    request: StopExecutionRequest,  # Reuse same request model (has execution_id)
    user: AuthUser = Security(requires_admin_user),
):
    """
    Requeue a stuck QUEUED execution (admin only).

    Uses add_graph_execution with existing graph_exec_id to requeue.

    ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.

    Args:
        request: Contains execution_id to requeue

    Returns:
        Success status and message
    """
    logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}")

    # Get the execution (validation - must be QUEUED)
    executions = await get_graph_executions(
        graph_exec_id=request.execution_id,
        statuses=[AgentExecutionStatus.QUEUED],
    )

    if not executions:
        raise HTTPException(
            status_code=404,
            detail="Execution not found or not in QUEUED status",
        )

    execution = executions[0]

    # Use add_graph_execution in requeue mode
    await add_graph_execution(
        graph_id=execution.graph_id,
        user_id=execution.user_id,
        graph_version=execution.graph_version,
        graph_exec_id=request.execution_id,  # Requeue existing execution
    )

    return RequeueExecutionResponse(
        success=True,
        requeued_count=1,
        message="Execution requeued successfully",
    )


@router.post(
    "/diagnostics/executions/requeue-bulk",
    response_model=RequeueExecutionResponse,
    summary="Requeue Multiple Stuck Executions",
)
async def requeue_multiple_executions(
    request: StopExecutionsRequest,  # Reuse same request model (has execution_ids)
    user: AuthUser = Security(requires_admin_user),
):
    """
    Requeue multiple stuck QUEUED executions (admin only).

    Uses add_graph_execution with existing graph_exec_id to requeue.

    ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.

    Args:
        request: Contains list of execution_ids to requeue

    Returns:
        Number of executions requeued and success message
    """
    logger.info(
        f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions"
    )

    # Get executions by ID list (must be QUEUED)
    executions = await get_graph_executions(
        execution_ids=request.execution_ids,
        statuses=[AgentExecutionStatus.QUEUED],
    )

    if not executions:
        return RequeueExecutionResponse(
            success=False,
            requeued_count=0,
            message="No QUEUED executions found to requeue",
        )

    # Requeue all executions in parallel using add_graph_execution
    async def requeue_one(exec) -> bool:
        try:
            await add_graph_execution(
                graph_id=exec.graph_id,
                user_id=exec.user_id,
                graph_version=exec.graph_version,
                graph_exec_id=exec.id,  # Requeue existing
            )
            return True
        except Exception as e:
            logger.error(f"Failed to requeue {exec.id}: {e}")
            return False

    results = await asyncio.gather(
        *[requeue_one(exec) for exec in executions], return_exceptions=False
    )

    requeued_count = sum(1 for success in results if success)

    return RequeueExecutionResponse(
        success=requeued_count > 0,
        requeued_count=requeued_count,
        message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions",
    )

@router.post(
    "/diagnostics/executions/stop",
    response_model=StopExecutionResponse,
    summary="Stop Single Execution",
)
async def stop_single_execution(
    request: StopExecutionRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Stop a single execution (admin only).

    Uses robust stop_graph_execution which cascades to children and waits for termination.

    Args:
        request: Contains execution_id to stop

    Returns:
        Success status and message
    """
    logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}")

    # Get the execution to find its owner user_id (required by stop_graph_execution)
    executions = await get_graph_executions(
        graph_exec_id=request.execution_id,
    )

    if not executions:
        raise HTTPException(status_code=404, detail="Execution not found")

    execution = executions[0]

    # Use robust stop_graph_execution (cascades to children, waits for termination)
    await stop_graph_execution(
        user_id=execution.user_id,
        graph_exec_id=request.execution_id,
        wait_timeout=15.0,
        cascade=True,
    )

    return StopExecutionResponse(
        success=True,
        stopped_count=1,
        message="Execution stopped successfully",
    )


@router.post(
    "/diagnostics/executions/stop-bulk",
    response_model=StopExecutionResponse,
    summary="Stop Multiple Executions",
)
async def stop_multiple_executions(
    request: StopExecutionsRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Stop multiple active executions (admin only).

    Uses robust stop_graph_execution which cascades to children and waits for termination.

    Args:
        request: Contains list of execution_ids to stop

    Returns:
        Number of executions stopped and success message
    """

    logger.info(
        f"Admin {user.user_id} stopping {len(request.execution_ids)} executions"
    )

    # Get executions by ID list
    executions = await get_graph_executions(
        execution_ids=request.execution_ids,
    )

    if not executions:
        return StopExecutionResponse(
            success=False,
            stopped_count=0,
            message="No executions found",
        )

    # Stop all executions in parallel using robust stop_graph_execution
    async def stop_one(exec) -> bool:
        try:
            await stop_graph_execution(
                user_id=exec.user_id,
                graph_exec_id=exec.id,
                wait_timeout=15.0,
                cascade=True,
            )
            return True
        except Exception as e:
            logger.error(f"Failed to stop execution {exec.id}: {e}")
            return False

    results = await asyncio.gather(
        *[stop_one(exec) for exec in executions], return_exceptions=False
    )

    stopped_count = sum(1 for success in results if success)

    return StopExecutionResponse(
        success=stopped_count > 0,
        stopped_count=stopped_count,
        message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions",
    )


@router.post(
    "/diagnostics/executions/cleanup-orphaned",
    response_model=StopExecutionResponse,
    summary="Cleanup Orphaned Executions",
)
async def cleanup_orphaned_executions(
    request: StopExecutionsRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Cleanup orphaned executions by directly updating DB status (admin only).
    For executions in DB but not actually running in executor (old/stale records).

    Args:
        request: Contains list of execution_ids to cleanup

    Returns:
        Number of executions cleaned up and success message
    """
    logger.info(
        f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions"
    )

    cleaned_count = await cleanup_orphaned_executions_bulk(
        request.execution_ids, user.user_id
    )

    return StopExecutionResponse(
        success=cleaned_count > 0,
        stopped_count=cleaned_count,
        message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions",
    )

# ============================================================================
# SCHEDULE DIAGNOSTICS ENDPOINTS
# ============================================================================


class SchedulesListResponse(BaseModel):
    """Response model for list of schedules"""

    schedules: List[ScheduleDetail]
    total: int


class OrphanedSchedulesListResponse(BaseModel):
    """Response model for list of orphaned schedules"""

    schedules: List[OrphanedScheduleDetail]
    total: int


class ScheduleCleanupRequest(BaseModel):
    """Request model for cleaning up schedules"""

    schedule_ids: List[str]


class ScheduleCleanupResponse(BaseModel):
    """Response model for schedule cleanup operations"""

    success: bool
    deleted_count: int = 0
    message: str


@router.get(
    "/diagnostics/schedules",
    response_model=ScheduleHealthMetrics,
    summary="Get Schedule Diagnostics",
)
async def get_schedule_diagnostics_endpoint():
    """
    Get comprehensive diagnostic information about schedule health.

    Returns schedule metrics including:
    - Total schedules (user vs system)
    - Orphaned schedules by category
    - Upcoming executions
    """
    logger.info("Getting schedule diagnostics")

    diagnostics = await get_schedule_health_metrics()

    logger.info(
        f"Schedule diagnostics: total={diagnostics.total_schedules}, "
        f"user={diagnostics.user_schedules}, "
        f"orphaned={diagnostics.total_orphaned}"
    )

    return diagnostics


@router.get(
    "/diagnostics/schedules/all",
    response_model=SchedulesListResponse,
    summary="List All User Schedules",
)
async def list_all_schedules(
    limit: int = 100,
    offset: int = 0,
):
    """
    Get detailed list of all user schedules (excludes system monitoring jobs).

    Args:
        limit: Maximum number of schedules to return (default 100)
        offset: Number of schedules to skip (default 0)

    Returns:
        List of schedules with details
    """
    logger.info(f"Listing all schedules (limit={limit}, offset={offset})")

    schedules = await get_all_schedules_details(limit=limit, offset=offset)

    # Get total count
    diagnostics = await get_schedule_health_metrics()
    total = diagnostics.user_schedules

    return SchedulesListResponse(schedules=schedules, total=total)


@router.get(
    "/diagnostics/schedules/orphaned",
    response_model=OrphanedSchedulesListResponse,
    summary="List Orphaned Schedules",
)
async def list_orphaned_schedules():
    """
    Get detailed list of orphaned schedules with orphan reasons.

    Returns:
        List of orphaned schedules categorized by orphan type
    """
    logger.info("Listing orphaned schedules")

    schedules = await get_orphaned_schedules_details()

    return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules))


@router.post(
    "/diagnostics/schedules/cleanup-orphaned",
    response_model=ScheduleCleanupResponse,
    summary="Cleanup Orphaned Schedules",
)
async def cleanup_orphaned_schedules(
    request: ScheduleCleanupRequest,
    user: AuthUser = Security(requires_admin_user),
):
    """
    Cleanup orphaned schedules by deleting from scheduler (admin only).

    Args:
        request: Contains list of schedule_ids to delete

    Returns:
        Number of schedules deleted and success message
    """
    logger.info(
        f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules"
    )

    deleted_count = await cleanup_orphaned_schedules_bulk(
        request.schedule_ids, user.user_id
    )

    return ScheduleCleanupResponse(
        success=deleted_count > 0,
        deleted_count=deleted_count,
        message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules",
    )

@router.post(
|
||||||
|
"/diagnostics/executions/stop-all-long-running",
|
||||||
|
response_model=StopExecutionResponse,
|
||||||
|
summary="Stop ALL Long-Running Executions",
|
||||||
|
)
|
||||||
|
async def stop_all_long_running_executions_endpoint(
|
||||||
|
user: AuthUser = Security(requires_admin_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only).
|
||||||
|
Operates on entire dataset, not limited to pagination.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of executions stopped and success message
|
||||||
|
"""
|
||||||
|
logger.info(f"Admin {user.user_id} stopping ALL long-running executions")
|
||||||
|
|
||||||
|
stopped_count = await stop_all_long_running_executions(user.user_id)
|
||||||
|
|
||||||
|
return StopExecutionResponse(
|
||||||
|
success=stopped_count > 0,
|
||||||
|
stopped_count=stopped_count,
|
||||||
|
message=f"Stopped {stopped_count} long-running executions",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/diagnostics/executions/cleanup-all-orphaned",
|
||||||
|
response_model=StopExecutionResponse,
|
||||||
|
summary="Cleanup ALL Orphaned Executions",
|
||||||
|
)
|
||||||
|
async def cleanup_all_orphaned_executions(
|
||||||
|
user: AuthUser = Security(requires_admin_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Cleanup ALL orphaned executions (>24h old) by directly updating DB status.
|
||||||
|
Operates on all executions, not just paginated results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of executions cleaned up and success message
|
||||||
|
"""
|
||||||
|
logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions")
|
||||||
|
|
||||||
|
# Fetch all orphaned execution IDs
|
||||||
|
execution_ids = await get_all_orphaned_execution_ids()
|
||||||
|
|
||||||
|
if not execution_ids:
|
||||||
|
return StopExecutionResponse(
|
||||||
|
success=True,
|
||||||
|
stopped_count=0,
|
||||||
|
message="No orphaned executions to cleanup",
|
||||||
|
)
|
||||||
|
|
||||||
|
cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id)
|
||||||
|
|
||||||
|
return StopExecutionResponse(
|
||||||
|
success=cleaned_count > 0,
|
||||||
|
stopped_count=cleaned_count,
|
||||||
|
message=f"Cleaned up {cleaned_count} orphaned executions",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/diagnostics/executions/cleanup-all-stuck-queued",
|
||||||
|
response_model=StopExecutionResponse,
|
||||||
|
summary="Cleanup ALL Stuck Queued Executions",
|
||||||
|
)
|
||||||
|
async def cleanup_all_stuck_queued_executions_endpoint(
|
||||||
|
user: AuthUser = Security(requires_admin_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only).
|
||||||
|
Operates on entire dataset, not limited to pagination.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of executions cleaned up and success message
|
||||||
|
"""
|
||||||
|
logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions")
|
||||||
|
|
||||||
|
cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id)
|
||||||
|
|
||||||
|
return StopExecutionResponse(
|
||||||
|
success=cleaned_count > 0,
|
||||||
|
stopped_count=cleaned_count,
|
||||||
|
message=f"Cleaned up {cleaned_count} stuck queued executions",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/diagnostics/executions/requeue-all-stuck",
|
||||||
|
response_model=RequeueExecutionResponse,
|
||||||
|
summary="Requeue ALL Stuck Queued Executions",
|
||||||
|
)
|
||||||
|
async def requeue_all_stuck_executions(
|
||||||
|
user: AuthUser = Security(requires_admin_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ.
|
||||||
|
Operates on all executions, not just paginated results.
|
||||||
|
|
||||||
|
Uses add_graph_execution with existing graph_exec_id to requeue.
|
||||||
|
|
||||||
|
⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of executions requeued and success message
|
||||||
|
"""
|
||||||
|
logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions")
|
||||||
|
|
||||||
|
# Fetch all stuck queued execution IDs
|
||||||
|
execution_ids = await get_all_stuck_queued_execution_ids()
|
||||||
|
|
||||||
|
if not execution_ids:
|
||||||
|
return RequeueExecutionResponse(
|
||||||
|
success=True,
|
||||||
|
requeued_count=0,
|
||||||
|
message="No stuck queued executions to requeue",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get stuck executions by ID list (must be QUEUED)
|
||||||
|
executions = await get_graph_executions(
|
||||||
|
execution_ids=execution_ids,
|
||||||
|
statuses=[AgentExecutionStatus.QUEUED],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Requeue all in parallel using add_graph_execution
|
||||||
|
async def requeue_one(exec) -> bool:
|
||||||
|
try:
|
||||||
|
await add_graph_execution(
|
||||||
|
graph_id=exec.graph_id,
|
||||||
|
user_id=exec.user_id,
|
||||||
|
graph_version=exec.graph_version,
|
||||||
|
graph_exec_id=exec.id, # Requeue existing
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to requeue {exec.id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[requeue_one(exec) for exec in executions], return_exceptions=False
|
||||||
|
)
|
||||||
|
|
||||||
|
requeued_count = sum(1 for success in results if success)
|
||||||
|
|
||||||
|
return RequeueExecutionResponse(
|
||||||
|
success=requeued_count > 0,
|
||||||
|
requeued_count=requeued_count,
|
||||||
|
message=f"Requeued {requeued_count} stuck executions",
|
||||||
|
)
|
||||||
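The `requeue_one` wrapper above catches per-item errors before fanning out with `asyncio.gather`, so one bad execution cannot abort the whole batch. A minimal, self-contained sketch of that pattern, using a hypothetical `process_one` in place of `add_graph_execution`:

```python
import asyncio


async def process_one(item: str) -> bool:
    # Stand-in for add_graph_execution: one item fails to show error isolation.
    if item == "exec-bad":
        raise RuntimeError("publish failed")
    return True


async def process_all(items: list[str]) -> int:
    async def safe(item: str) -> bool:
        # Mirror requeue_one: swallow per-item errors and report False
        # instead of letting the exception propagate out of gather().
        try:
            return await process_one(item)
        except Exception:
            return False

    results = await asyncio.gather(*(safe(i) for i in items))
    return sum(1 for ok in results if ok)


count = asyncio.run(process_all(["exec-0", "exec-bad", "exec-1"]))
# count == 2: the failing item is counted as a miss, not raised
```

Because every coroutine already returns a bool, `return_exceptions=False` (the default) is safe here; the try/except is what provides the isolation, not `gather` itself.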
@@ -0,0 +1,889 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma.enums import AgentExecutionStatus

import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes
from backend.data.diagnostics import (
    AgentDiagnosticsSummary,
    ExecutionDiagnosticsSummary,
    FailedExecutionDetail,
    OrphanedScheduleDetail,
    RunningExecutionDetail,
    ScheduleDetail,
    ScheduleHealthMetrics,
)
from backend.data.execution import GraphExecutionMeta

app = fastapi.FastAPI()
app.include_router(diagnostics_admin_routes.router)

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Set up admin auth overrides for all tests in this module"""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_get_execution_diagnostics_success(
    mocker: pytest_mock.MockFixture,
):
    """Test fetching execution diagnostics with invalid state detection"""
    mock_diagnostics = ExecutionDiagnosticsSummary(
        running_count=10,
        queued_db_count=5,
        rabbitmq_queue_depth=3,
        cancel_queue_depth=0,
        orphaned_running=2,
        orphaned_queued=1,
        failed_count_1h=5,
        failed_count_24h=20,
        failure_rate_24h=0.83,
        stuck_running_24h=1,
        stuck_running_1h=3,
        oldest_running_hours=26.5,
        stuck_queued_1h=2,
        queued_never_started=1,
        invalid_queued_with_start=1,  # New invalid state
        invalid_running_without_start=1,  # New invalid state
        completed_1h=50,
        completed_24h=1200,
        throughput_per_hour=50.0,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=mock_diagnostics,
    )

    response = client.get("/admin/diagnostics/executions")

    assert response.status_code == 200
    data = response.json()

    # Verify new invalid state fields are included
    assert data["invalid_queued_with_start"] == 1
    assert data["invalid_running_without_start"] == 1
    # Verify all expected fields present
    assert "running_count" in data
    assert "orphaned_running" in data
    assert "failed_count_24h" in data


def test_list_invalid_executions(
    mocker: pytest_mock.MockFixture,
):
    """Test listing executions in invalid states (read-only endpoint)"""
    mock_invalid_executions = [
        RunningExecutionDetail(
            execution_id="exec-invalid-1",
            graph_id="graph-123",
            graph_name="Test Graph",
            graph_version=1,
            user_id="user-123",
            user_email="test@example.com",
            status="QUEUED",
            created_at=datetime.now(timezone.utc),
            started_at=datetime.now(
                timezone.utc
            ),  # QUEUED but has startedAt - INVALID!
            queue_status=None,
        ),
        RunningExecutionDetail(
            execution_id="exec-invalid-2",
            graph_id="graph-456",
            graph_name="Another Graph",
            graph_version=2,
            user_id="user-456",
            user_email="user@example.com",
            status="RUNNING",
            created_at=datetime.now(timezone.utc),
            started_at=None,  # RUNNING but no startedAt - INVALID!
            queue_status=None,
        ),
    ]

    mock_diagnostics = ExecutionDiagnosticsSummary(
        running_count=10,
        queued_db_count=5,
        rabbitmq_queue_depth=3,
        cancel_queue_depth=0,
        orphaned_running=0,
        orphaned_queued=0,
        failed_count_1h=0,
        failed_count_24h=0,
        failure_rate_24h=0.0,
        stuck_running_24h=0,
        stuck_running_1h=0,
        oldest_running_hours=None,
        stuck_queued_1h=0,
        queued_never_started=0,
        invalid_queued_with_start=1,
        invalid_running_without_start=1,
        completed_1h=0,
        completed_24h=0,
        throughput_per_hour=0.0,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details",
        return_value=mock_invalid_executions,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=mock_diagnostics,
    )

    response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 2  # Sum of both invalid state types
    assert len(data["executions"]) == 2
    # Verify both types of invalid states are returned
    assert data["executions"][0]["execution_id"] in [
        "exec-invalid-1",
        "exec-invalid-2",
    ]
    assert data["executions"][1]["execution_id"] in [
        "exec-invalid-1",
        "exec-invalid-2",
    ]


def test_requeue_single_execution_with_add_graph_execution(
    mocker: pytest_mock.MockFixture,
    admin_user_id: str,
):
    """Test requeueing uses add_graph_execution in requeue mode"""
    mock_exec_meta = GraphExecutionMeta(
        id="exec-stuck-123",
        user_id="user-123",
        graph_id="graph-456",
        graph_version=1,
        inputs=None,
        credential_inputs=None,
        nodes_input_masks=None,
        preset_id=None,
        status=AgentExecutionStatus.QUEUED,
        started_at=datetime.now(timezone.utc),
        ended_at=datetime.now(timezone.utc),
        stats=None,
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[mock_exec_meta],
    )

    mock_add_graph_execution = mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post(
        "/admin/diagnostics/executions/requeue",
        json={"execution_id": "exec-stuck-123"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["requeued_count"] == 1

    # Verify it used add_graph_execution in requeue mode
    mock_add_graph_execution.assert_called_once()
    call_kwargs = mock_add_graph_execution.call_args.kwargs
    assert call_kwargs["graph_exec_id"] == "exec-stuck-123"  # Requeue mode!
    assert call_kwargs["graph_id"] == "graph-456"
    assert call_kwargs["user_id"] == "user-123"


def test_stop_single_execution_with_stop_graph_execution(
    mocker: pytest_mock.MockFixture,
    admin_user_id: str,
):
    """Test stopping uses robust stop_graph_execution"""
    mock_exec_meta = GraphExecutionMeta(
        id="exec-running-123",
        user_id="user-789",
        graph_id="graph-999",
        graph_version=2,
        inputs=None,
        credential_inputs=None,
        nodes_input_masks=None,
        preset_id=None,
        status=AgentExecutionStatus.RUNNING,
        started_at=datetime.now(timezone.utc),
        ended_at=datetime.now(timezone.utc),
        stats=None,
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[mock_exec_meta],
    )

    mock_stop_graph_execution = mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post(
        "/admin/diagnostics/executions/stop",
        json={"execution_id": "exec-running-123"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 1

    # Verify it used stop_graph_execution with cascade
    mock_stop_graph_execution.assert_called_once()
    call_kwargs = mock_stop_graph_execution.call_args.kwargs
    assert call_kwargs["graph_exec_id"] == "exec-running-123"
    assert call_kwargs["user_id"] == "user-789"
    assert call_kwargs["cascade"] is True  # Stops children too!
    assert call_kwargs["wait_timeout"] == 15.0


def test_requeue_not_queued_execution_fails(
    mocker: pytest_mock.MockFixture,
):
    """Test that requeue fails if execution is not in QUEUED status"""
    # Mock an execution that's RUNNING (not QUEUED)
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=[],  # No QUEUED executions found
    )

    response = client.post(
        "/admin/diagnostics/executions/requeue",
        json={"execution_id": "exec-running-123"},
    )

    assert response.status_code == 404
    assert "not found or not in QUEUED status" in response.json()["detail"]


def test_list_invalid_executions_no_bulk_actions(
    mocker: pytest_mock.MockFixture,
):
    """Verify invalid executions endpoint is read-only (no bulk actions)"""
    # This is a documentation test - the endpoint exists but should not
    # have corresponding cleanup/stop/requeue endpoints

    # These endpoints should NOT exist for invalid states:
    invalid_bulk_endpoints = [
        "/admin/diagnostics/executions/cleanup-invalid",
        "/admin/diagnostics/executions/stop-invalid",
        "/admin/diagnostics/executions/requeue-invalid",
    ]

    for endpoint in invalid_bulk_endpoints:
        response = client.post(endpoint, json={"execution_ids": ["test"]})
        assert response.status_code == 404, f"{endpoint} should not exist (read-only)"


def test_execution_ids_filter_efficiency(
    mocker: pytest_mock.MockFixture,
):
    """Test that bulk operations use efficient execution_ids filter"""
    mock_exec_metas = [
        GraphExecutionMeta(
            id=f"exec-{i}",
            user_id=f"user-{i}",
            graph_id="graph-123",
            graph_version=1,
            inputs=None,
            credential_inputs=None,
            nodes_input_masks=None,
            preset_id=None,
            status=AgentExecutionStatus.QUEUED,
            started_at=datetime.now(timezone.utc),
            ended_at=datetime.now(timezone.utc),
            stats=None,
        )
        for i in range(3)
    ]

    mock_get_graph_executions = mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=mock_exec_metas,
    )

    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post(
        "/admin/diagnostics/executions/requeue-bulk",
        json={"execution_ids": ["exec-0", "exec-1", "exec-2"]},
    )

    assert response.status_code == 200

    # Verify it used execution_ids filter (not fetching all queued)
    mock_get_graph_executions.assert_called_once()
    call_kwargs = mock_get_graph_executions.call_args.kwargs
    assert "execution_ids" in call_kwargs
    assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"]
    assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED]


# ---------------------------------------------------------------------------
# Helper: reusable mock diagnostics summary
# ---------------------------------------------------------------------------


def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary:
    defaults = dict(
        running_count=10,
        queued_db_count=5,
        rabbitmq_queue_depth=3,
        cancel_queue_depth=0,
        orphaned_running=2,
        orphaned_queued=1,
        failed_count_1h=5,
        failed_count_24h=20,
        failure_rate_24h=0.83,
        stuck_running_24h=3,
        stuck_running_1h=5,
        oldest_running_hours=26.5,
        stuck_queued_1h=2,
        queued_never_started=1,
        invalid_queued_with_start=1,
        invalid_running_without_start=1,
        completed_1h=50,
        completed_24h=1200,
        throughput_per_hour=50.0,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    defaults.update(overrides)
    return ExecutionDiagnosticsSummary(**defaults)
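The `_make_mock_diagnostics` helper above uses the defaults-plus-overrides factory idiom: a test names only the fields it cares about and inherits a healthy baseline for the rest. A reduced sketch of the same idiom, with a hypothetical `Summary` dataclass standing in for `ExecutionDiagnosticsSummary`:

```python
from dataclasses import dataclass


@dataclass
class Summary:
    running_count: int
    queued_db_count: int
    failure_rate_24h: float


def make_summary(**overrides) -> Summary:
    # Healthy baseline; each call overrides only what the test asserts on.
    defaults = dict(running_count=10, queued_db_count=5, failure_rate_24h=0.83)
    defaults.update(overrides)
    return Summary(**defaults)


s = make_summary(running_count=0)
# s.running_count == 0, the other fields keep their defaults
```

This keeps tests short and makes each test's intent obvious: the overridden field is the one under test.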


_SENTINEL = object()


def _make_mock_execution(
    exec_id: str = "exec-1",
    status: str = "RUNNING",
    started_at: datetime | None | object = _SENTINEL,
) -> RunningExecutionDetail:
    return RunningExecutionDetail(
        execution_id=exec_id,
        graph_id="graph-123",
        graph_name="Test Graph",
        graph_version=1,
        user_id="user-123",
        user_email="test@example.com",
        status=status,
        created_at=datetime.now(timezone.utc),
        started_at=(
            datetime.now(timezone.utc) if started_at is _SENTINEL else started_at
        ),
        queue_status=None,
    )
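`_SENTINEL` exists because `None` is a meaningful value for `started_at` (a stuck-queued execution has no start time), so the factory needs to distinguish "caller omitted the argument" from "caller explicitly passed `None`". A self-contained sketch of this sentinel-default idiom (the `describe` helper is illustrative only):

```python
_SENTINEL = object()  # unique object; no caller can pass it by accident


def describe(started_at=_SENTINEL) -> str:
    if started_at is _SENTINEL:
        return "defaulted"  # argument omitted: factory picks a fresh default
    if started_at is None:
        return "explicitly None"  # caller really wants None (e.g. never started)
    return "explicit value"


omitted = describe()
explicit_none = describe(None)
explicit_value = describe("2024-01-01T00:00:00Z")
```

Identity comparison (`is _SENTINEL`) is what makes this safe: the sentinel is a fresh `object()` that equals nothing else.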


def _make_mock_failed_execution(
    exec_id: str = "exec-fail-1",
) -> FailedExecutionDetail:
    return FailedExecutionDetail(
        execution_id=exec_id,
        graph_id="graph-123",
        graph_name="Test Graph",
        graph_version=1,
        user_id="user-123",
        user_email="test@example.com",
        status="FAILED",
        created_at=datetime.now(timezone.utc),
        started_at=datetime.now(timezone.utc),
        failed_at=datetime.now(timezone.utc),
        error_message="Something went wrong",
    )


def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics:
    defaults = dict(
        total_schedules=15,
        user_schedules=10,
        system_schedules=5,
        orphaned_deleted_graph=2,
        orphaned_no_library_access=1,
        orphaned_invalid_credentials=0,
        orphaned_validation_failed=0,
        total_orphaned=3,
        schedules_next_hour=4,
        schedules_next_24h=8,
        total_runs_next_hour=12,
        total_runs_next_24h=48,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    defaults.update(overrides)
    return ScheduleHealthMetrics(**defaults)


# ---------------------------------------------------------------------------
# GET endpoints: execution list variants
# ---------------------------------------------------------------------------


def test_list_running_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [
        _make_mock_execution("exec-run-1"),
        _make_mock_execution("exec-run-2"),
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 15  # running_count(10) + queued_db_count(5)
    assert len(data["executions"]) == 2
    assert data["executions"][0]["execution_id"] == "exec-run-1"
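Note the contract the assertions above rely on: `total` comes from the diagnostics summary and describes the whole dataset, while `executions` is only the requested `limit`/`offset` window. A tiny sketch of that total-versus-page contract (the `paginate` helper is hypothetical):

```python
def paginate(items: list, limit: int, offset: int) -> dict:
    # total reflects the full dataset; "page" is just the limit/offset window
    return {"total": len(items), "page": items[offset : offset + limit]}


out = paginate(list(range(12)), limit=5, offset=10)
# out["total"] == 12 even though only 2 items fit in the final page
```

Clients use `total` to render page controls, so it must not shrink to the page size.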


def test_list_orphaned_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 3  # orphaned_running(2) + orphaned_queued(1)
    assert len(data["executions"]) == 1


def test_list_failed_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [_make_mock_failed_execution("exec-fail-1")]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count",
        return_value=42,
    )

    response = client.get(
        "/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24"
    )

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 42
    assert len(data["executions"]) == 1
    assert data["executions"][0]["error_message"] == "Something went wrong"


def test_list_long_running_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [_make_mock_execution("exec-long-1")]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get(
        "/admin/diagnostics/executions/long-running?limit=50&offset=0"
    )

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 3  # stuck_running_24h
    assert len(data["executions"]) == 1


def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture):
    mock_execs = [
        _make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None)
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details",
        return_value=mock_execs,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
        return_value=_make_mock_diagnostics(),
    )

    response = client.get(
        "/admin/diagnostics/executions/stuck-queued?limit=50&offset=0"
    )

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 2  # stuck_queued_1h
    assert len(data["executions"]) == 1


# ---------------------------------------------------------------------------
# GET endpoints: agent + schedule diagnostics
# ---------------------------------------------------------------------------


def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture):
    mock_diag = AgentDiagnosticsSummary(
        agents_with_active_executions=7,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics",
        return_value=mock_diag,
    )

    response = client.get("/admin/diagnostics/agents")

    assert response.status_code == 200
    data = response.json()
    assert data["agents_with_active_executions"] == 7


def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture):
    mock_metrics = _make_mock_schedule_health()
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
        return_value=mock_metrics,
    )

    response = client.get("/admin/diagnostics/schedules")

    assert response.status_code == 200
    data = response.json()
    assert data["user_schedules"] == 10
    assert data["total_orphaned"] == 3
    assert data["total_runs_next_hour"] == 12


def test_list_all_schedules(mocker: pytest_mock.MockFixture):
    mock_schedules = [
        ScheduleDetail(
            schedule_id="sched-1",
            schedule_name="Daily Run",
            graph_id="graph-1",
            graph_name="My Agent",
            graph_version=1,
            user_id="user-1",
            user_email="alice@example.com",
            cron="0 9 * * *",
            timezone="UTC",
            next_run_time=datetime.now(timezone.utc).isoformat(),
        ),
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details",
        return_value=mock_schedules,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
        return_value=_make_mock_schedule_health(),
    )

    response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 10
    assert len(data["schedules"]) == 1
    assert data["schedules"][0]["schedule_name"] == "Daily Run"


def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture):
    mock_orphans = [
        OrphanedScheduleDetail(
            schedule_id="sched-orphan-1",
            schedule_name="Ghost Schedule",
            graph_id="graph-deleted",
            graph_version=1,
            user_id="user-1",
            orphan_reason="deleted_graph",
            error_detail=None,
            next_run_time=datetime.now(timezone.utc).isoformat(),
        ),
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details",
        return_value=mock_orphans,
    )

    response = client.get("/admin/diagnostics/schedules/orphaned")

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 1
    assert data["schedules"][0]["orphan_reason"] == "deleted_graph"


# ---------------------------------------------------------------------------
# POST endpoints: bulk stop, cleanup, requeue
# ---------------------------------------------------------------------------


def test_stop_multiple_executions(mocker: pytest_mock.MockFixture):
    mock_exec_metas = [
        GraphExecutionMeta(
            id=f"exec-{i}",
            user_id=f"user-{i}",
            graph_id="graph-123",
            graph_version=1,
            inputs=None,
            credential_inputs=None,
            nodes_input_masks=None,
            preset_id=None,
            status=AgentExecutionStatus.RUNNING,
            started_at=datetime.now(timezone.utc),
            ended_at=None,
            stats=None,
        )
        for i in range(2)
    ]
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
        return_value=mock_exec_metas,
    )
    mocker.patch(
        "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
        return_value=AsyncMock(),
    )

    response = client.post(
        "/admin/diagnostics/executions/stop-bulk",
        json={"execution_ids": ["exec-0", "exec-1"]},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
    assert data["stopped_count"] == 2


def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture):
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||||
|
return_value=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/diagnostics/executions/stop-bulk",
|
||||||
|
json={"execution_ids": ["nonexistent"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is False
|
||||||
|
assert data["stopped_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
|
||||||
|
return_value=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/diagnostics/executions/cleanup-orphaned",
|
||||||
|
json={"execution_ids": ["exec-1", "exec-2", "exec-3"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["stopped_count"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk",
|
||||||
|
return_value=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/diagnostics/schedules/cleanup-orphaned",
|
||||||
|
json={"schedule_ids": ["sched-1", "sched-2"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["deleted_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions",
|
||||||
|
return_value=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/diagnostics/executions/stop-all-long-running")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["stopped_count"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
|
||||||
|
return_value=["exec-1", "exec-2"],
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
|
||||||
|
return_value=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["stopped_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
|
||||||
|
return_value=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["stopped_count"] == 0
|
||||||
|
assert "No orphaned" in data["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions",
|
||||||
|
return_value=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["stopped_count"] == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture):
|
||||||
|
mock_exec_metas = [
|
||||||
|
GraphExecutionMeta(
|
||||||
|
id=f"exec-stuck-{i}",
|
||||||
|
user_id=f"user-{i}",
|
||||||
|
graph_id="graph-123",
|
||||||
|
graph_version=1,
|
||||||
|
inputs=None,
|
||||||
|
credential_inputs=None,
|
||||||
|
nodes_input_masks=None,
|
||||||
|
preset_id=None,
|
||||||
|
status=AgentExecutionStatus.QUEUED,
|
||||||
|
started_at=None,
|
||||||
|
ended_at=None,
|
||||||
|
stats=None,
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
|
||||||
|
return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"],
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||||
|
return_value=mock_exec_metas,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
|
||||||
|
return_value=AsyncMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["requeued_count"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
|
||||||
|
return_value=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["requeued_count"] == 0
|
||||||
|
assert "No stuck" in data["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||||
|
return_value=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/diagnostics/executions/requeue-bulk",
|
||||||
|
json={"execution_ids": ["nonexistent"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is False
|
||||||
|
assert data["requeued_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
|
||||||
|
return_value=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/diagnostics/executions/stop",
|
||||||
|
json={"execution_id": "nonexistent"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert "not found" in response.json()["detail"]
|
||||||
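Every test above patches names at `backend.api.features.admin.diagnostics_admin_routes.<name>`, i.e. in the module that *uses* the function, not the module that defines it. A minimal, repo-independent sketch of why that target matters (the `demo_db` / `demo_routes` module names are invented for illustration):

```python
import sys
import types
from unittest.mock import patch

# Build two throwaway modules: "demo_db" defines get_count,
# "demo_routes" imports it with `from demo_db import get_count`.
demo_db = types.ModuleType("demo_db")
demo_db.get_count = lambda: 7
sys.modules["demo_db"] = demo_db

demo_routes = types.ModuleType("demo_routes")
exec(
    "from demo_db import get_count\n"
    "def handler():\n"
    "    return get_count()\n",
    demo_routes.__dict__,
)
sys.modules["demo_routes"] = demo_routes

# Patching the defining module does NOT affect the reference
# demo_routes already imported at module load time.
with patch("demo_db.get_count", return_value=99):
    result_def_module = demo_routes.handler()   # still the original 7

# Patching the *using* module's name is what the tests above rely on.
with patch("demo_routes.get_count", return_value=99):
    result_use_module = demo_routes.handler()   # now the mocked 99
```

This is the standard "patch where it's looked up" rule; `mocker.patch` (pytest-mock) wraps `unittest.mock.patch` and behaves the same way.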
@@ -14,3 +14,70 @@ class UserHistoryResponse(BaseModel):
 class AddUserCreditsResponse(BaseModel):
     new_balance: int
     transaction_key: str
+
+
+class ExecutionDiagnosticsResponse(BaseModel):
+    """Response model for execution diagnostics"""
+
+    # Current execution state
+    running_executions: int
+    queued_executions_db: int
+    queued_executions_rabbitmq: int
+    cancel_queue_depth: int
+
+    # Orphaned execution detection
+    orphaned_running: int
+    orphaned_queued: int
+
+    # Failure metrics
+    failed_count_1h: int
+    failed_count_24h: int
+    failure_rate_24h: float
+
+    # Long-running detection
+    stuck_running_24h: int
+    stuck_running_1h: int
+    oldest_running_hours: float | None
+
+    # Stuck queued detection
+    stuck_queued_1h: int
+    queued_never_started: int
+
+    # Invalid state detection (data corruption - no auto-actions)
+    invalid_queued_with_start: int
+    invalid_running_without_start: int
+
+    # Throughput metrics
+    completed_1h: int
+    completed_24h: int
+    throughput_per_hour: float
+
+    timestamp: str
+
+
+class AgentDiagnosticsResponse(BaseModel):
+    """Response model for agent diagnostics"""
+
+    agents_with_active_executions: int
+    timestamp: str
+
+
+class ScheduleHealthMetrics(BaseModel):
+    """Response model for schedule diagnostics"""
+
+    total_schedules: int
+    user_schedules: int
+    system_schedules: int
+
+    # Orphan detection
+    orphaned_deleted_graph: int
+    orphaned_no_library_access: int
+    orphaned_invalid_credentials: int
+    orphaned_validation_failed: int
+    total_orphaned: int
+
+    # Upcoming
+    schedules_next_hour: int
+    schedules_next_24h: int
+
+    timestamp: str
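The new response models are plain field bundles that serialize straight to JSON for the admin UI. A trimmed dataclass stand-in (not the repo's actual pydantic class) shows the shape a client of the schedules endpoint receives, matching the values asserted in the tests above:

```python
from dataclasses import dataclass, asdict


# Trimmed stand-in for ScheduleHealthMetrics: a few of its fields,
# enough to show the serialized shape the admin endpoint returns.
@dataclass
class ScheduleHealthSketch:
    total_schedules: int
    user_schedules: int
    system_schedules: int
    total_orphaned: int
    timestamp: str


snapshot = asdict(
    ScheduleHealthSketch(
        total_schedules=12,
        user_schedules=10,
        system_schedules=2,
        total_orphaned=3,
        timestamp="2026-01-01T00:00:00+00:00",
    )
)
# snapshot is a plain dict, ready for JSON encoding
```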
@@ -32,10 +32,10 @@ router = APIRouter(
 class UserRateLimitResponse(BaseModel):
     user_id: str
     user_email: Optional[str] = None
-    daily_token_limit: int
-    weekly_token_limit: int
-    daily_tokens_used: int
-    weekly_tokens_used: int
+    daily_cost_limit_microdollars: int
+    weekly_cost_limit_microdollars: int
+    daily_cost_used_microdollars: int
+    weekly_cost_used_microdollars: int
     tier: SubscriptionTier

@@ -101,17 +101,19 @@ async def get_user_rate_limit(
     logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)

     daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        resolved_id, config.daily_token_limit, config.weekly_token_limit
+        resolved_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
     )
     usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)

     return UserRateLimitResponse(
         user_id=resolved_id,
         user_email=resolved_email,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
-        daily_tokens_used=usage.daily.used,
-        weekly_tokens_used=usage.weekly.used,
+        daily_cost_limit_microdollars=daily_limit,
+        weekly_cost_limit_microdollars=weekly_limit,
+        daily_cost_used_microdollars=usage.daily.used,
+        weekly_cost_used_microdollars=usage.weekly.used,
         tier=tier,
     )

@@ -141,7 +143,9 @@ async def reset_user_rate_limit(
     raise HTTPException(status_code=500, detail="Failed to reset usage") from e

     daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        user_id, config.daily_token_limit, config.weekly_token_limit
+        user_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
     )
     usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)

@@ -154,10 +158,10 @@ async def reset_user_rate_limit(
     return UserRateLimitResponse(
         user_id=user_id,
         user_email=resolved_email,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
-        daily_tokens_used=usage.daily.used,
-        weekly_tokens_used=usage.weekly.used,
+        daily_cost_limit_microdollars=daily_limit,
+        weekly_cost_limit_microdollars=weekly_limit,
+        daily_cost_used_microdollars=usage.daily.used,
+        weekly_cost_used_microdollars=usage.weekly.used,
         tier=tier,
     )
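The field renames above make the unit explicit: limits and usage are tracked as integer microdollars (10^-6 USD), so the fixture values used in the tests later in this diff (2_500_000 daily, 12_500_000 weekly) read as $2.50 and $12.50. A unit sanity check, with an illustrative helper name rather than anything from the repo:

```python
MICRODOLLARS_PER_DOLLAR = 1_000_000


def microdollars_to_dollars(amount_microdollars: int) -> float:
    # Limits and usage are stored as integers to avoid float drift;
    # convert to dollars only at display time.
    return amount_microdollars / MICRODOLLARS_PER_DOLLAR


daily = microdollars_to_dollars(2_500_000)    # 2.5, i.e. $2.50/day
weekly = microdollars_to_dollars(12_500_000)  # 12.5, i.e. $12.50/week
```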
@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
     mocker.patch(
         f"{_MOCK_MODULE}.get_global_rate_limits",
         new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
     )
     mocker.patch(
         f"{_MOCK_MODULE}.get_usage_status",

@@ -85,11 +85,11 @@ def test_get_rate_limit(
     data = response.json()
     assert data["user_id"] == target_user_id
     assert data["user_email"] == _TARGET_EMAIL
-    assert data["daily_token_limit"] == 2_500_000
-    assert data["weekly_token_limit"] == 12_500_000
-    assert data["daily_tokens_used"] == 500_000
-    assert data["weekly_tokens_used"] == 3_000_000
-    assert data["tier"] == "FREE"
+    assert data["daily_cost_limit_microdollars"] == 2_500_000
+    assert data["weekly_cost_limit_microdollars"] == 12_500_000
+    assert data["daily_cost_used_microdollars"] == 500_000
+    assert data["weekly_cost_used_microdollars"] == 3_000_000
+    assert data["tier"] == "BASIC"

     configured_snapshot.assert_match(
         json.dumps(data, indent=2, sort_keys=True) + "\n",

@@ -117,7 +117,7 @@ def test_get_rate_limit_by_email(
     data = response.json()
     assert data["user_id"] == target_user_id
     assert data["user_email"] == _TARGET_EMAIL
-    assert data["daily_token_limit"] == 2_500_000
+    assert data["daily_cost_limit_microdollars"] == 2_500_000


 def test_get_rate_limit_by_email_not_found(

@@ -160,10 +160,10 @@ def test_reset_user_usage_daily_only(

     assert response.status_code == 200
     data = response.json()
-    assert data["daily_tokens_used"] == 0
+    assert data["daily_cost_used_microdollars"] == 0
     # Weekly is untouched
-    assert data["weekly_tokens_used"] == 3_000_000
-    assert data["tier"] == "FREE"
+    assert data["weekly_cost_used_microdollars"] == 3_000_000
+    assert data["tier"] == "BASIC"

     mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)

@@ -192,9 +192,9 @@ def test_reset_user_usage_daily_and_weekly(

     assert response.status_code == 200
     data = response.json()
-    assert data["daily_tokens_used"] == 0
-    assert data["weekly_tokens_used"] == 0
-    assert data["tier"] == "FREE"
+    assert data["daily_cost_used_microdollars"] == 0
+    assert data["weekly_cost_used_microdollars"] == 0
+    assert data["tier"] == "BASIC"

     mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)

@@ -231,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
     mocker.patch(
         f"{_MOCK_MODULE}.get_global_rate_limits",
         new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
     )
     mocker.patch(
         f"{_MOCK_MODULE}.get_usage_status",

@@ -324,7 +324,7 @@ def test_set_user_tier(
     mocker.patch(
         f"{_MOCK_MODULE}.get_user_tier",
         new_callable=AsyncMock,
-        return_value=SubscriptionTier.FREE,
+        return_value=SubscriptionTier.BASIC,
     )
     mock_set = mocker.patch(
         f"{_MOCK_MODULE}.set_user_tier",

@@ -347,7 +347,7 @@ def test_set_user_tier_downgrade(
     mocker: pytest_mock.MockerFixture,
     target_user_id: str,
 ) -> None:
-    """Test downgrading a user's tier from PRO to FREE."""
+    """Test downgrading a user's tier from PRO to BASIC."""
     mocker.patch(
         f"{_MOCK_MODULE}.get_user_email_by_id",
         new_callable=AsyncMock,

@@ -365,14 +365,14 @@ def test_set_user_tier_downgrade(

     response = client.post(
         "/admin/rate_limit/tier",
-        json={"user_id": target_user_id, "tier": "FREE"},
+        json={"user_id": target_user_id, "tier": "BASIC"},
     )

     assert response.status_code == 200
     data = response.json()
     assert data["user_id"] == target_user_id
-    assert data["tier"] == "FREE"
-    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
+    assert data["tier"] == "BASIC"
+    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.BASIC)


 def test_set_user_tier_invalid_tier(

@@ -456,7 +456,7 @@ def test_set_user_tier_db_failure(
     mocker.patch(
         f"{_MOCK_MODULE}.get_user_tier",
         new_callable=AsyncMock,
-        return_value=SubscriptionTier.FREE,
+        return_value=SubscriptionTier.BASIC,
    )
     mocker.patch(
         f"{_MOCK_MODULE}.set_user_tier",
@@ -2,7 +2,6 @@

 import asyncio
 import logging
-import re
 from collections.abc import AsyncGenerator
 from typing import Annotated
 from uuid import uuid4

@@ -10,15 +9,14 @@ from uuid import uuid4
 from autogpt_libs import auth
 from fastapi import APIRouter, HTTPException, Query, Response, Security
 from fastapi.responses import StreamingResponse
-from prisma.models import UserWorkspaceFile
 from pydantic import BaseModel, ConfigDict, Field, field_validator

 from backend.copilot import service as chat_service
 from backend.copilot import stream_registry
+from backend.copilot.builder_context import resolve_session_permissions
 from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
 from backend.copilot.db import get_chat_messages_paginated
 from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
-from backend.copilot.message_dedup import acquire_dedup_lock
 from backend.copilot.model import (
     ChatMessage,
     ChatSession,

@@ -27,11 +25,18 @@ from backend.copilot.model import (
     create_chat_session,
     delete_chat_session,
     get_chat_session,
+    get_or_create_builder_session,
     get_user_sessions,
     update_session_title,
 )
+from backend.copilot.pending_message_helpers import (
+    QueuePendingMessageResponse,
+    is_turn_in_flight,
+    queue_pending_for_http,
+)
+from backend.copilot.pending_messages import peek_pending_messages
 from backend.copilot.rate_limit import (
-    CoPilotUsageStatus,
+    CoPilotUsagePublic,
     RateLimitExceeded,
     acquire_reset_lock,
     check_rate_limit,

@@ -42,7 +47,14 @@ from backend.copilot.rate_limit import (
     release_reset_lock,
     reset_daily_usage,
 )
-from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
+from backend.copilot.response_model import (
+    StreamError,
+    StreamFinish,
+    StreamFinishStep,
+    StreamHeartbeat,
+    StreamStart,
+    StreamStartStep,
+)
 from backend.copilot.service import strip_injected_context_for_display
 from backend.copilot.tools.e2b_sandbox import kill_sandbox
 from backend.copilot.tools.models import (

@@ -70,13 +82,14 @@ from backend.copilot.tools.models import (
     NoResultsResponse,
     SetupRequirementsResponse,
     SuggestedGoalResponse,
+    TodoWriteResponse,
     UnderstandingUpdatedResponse,
 )
 from backend.copilot.tracking import track_user_message
 from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
 from backend.data.redis_client import get_redis_async
 from backend.data.understanding import get_business_understanding
-from backend.data.workspace import get_or_create_workspace
+from backend.data.workspace import build_files_block, resolve_workspace_files
 from backend.util.exceptions import InsufficientBalanceError, NotFoundError
 from backend.util.settings import Settings

@@ -86,10 +99,6 @@ logger = logging.getLogger(__name__)

 config = ChatConfig()

-_UUID_RE = re.compile(
-    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
-)
-

 async def _validate_and_get_session(
     session_id: str,

@@ -134,7 +143,7 @@ def _strip_injected_context(message: dict) -> dict:
 class StreamChatRequest(BaseModel):
     """Request model for streaming chat with optional context."""

-    message: str
+    message: str = Field(max_length=64_000)
     is_user_message: bool = True
     context: dict[str, str] | None = None  # {url: str, content: str}
     file_ids: list[str] | None = Field(

@@ -152,16 +161,53 @@ class StreamChatRequest(BaseModel):
     )


-class CreateSessionRequest(BaseModel):
-    """Request model for creating a new chat session.
+class QueuePendingMessageRequest(BaseModel):
+    """Request model for queueing a follow-up while a turn is running."""

+    message: str = Field(max_length=64_000)
+    context: dict[str, str] | None = None
+    file_ids: list[str] | None = Field(default=None, max_length=20)
+
+
+class PeekPendingMessagesResponse(BaseModel):
+    """Response for the pending-message peek (GET) endpoint.
+
+    Returns a read-only view of the pending buffer — messages are NOT
+    consumed. The frontend uses this to restore the queued-message
+    indicator after a page refresh and to decide when to clear it once
+    a turn has ended.
+    """
+
+    messages: list[str]
+    count: int
+
+
+class CreateSessionRequest(BaseModel):
+    """Request model for creating (or get-or-creating) a chat session.
+
+    Two modes, selected by the body:

-    ``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
+    - Default: create a fresh session. ``dry_run`` is a **top-level**
+      field — do not nest it inside ``metadata``.
+    - Builder-bound: when ``builder_graph_id`` is set, the endpoint
+      switches to **get-or-create** keyed on
+      ``(user_id, builder_graph_id)``. The builder panel calls this on
+      mount so the chat persists across refreshes. Graph ownership is
+      validated inside :func:`get_or_create_builder_session`. Write-side
+      scope is enforced per-tool (``edit_agent`` / ``run_agent`` reject
+      any ``agent_id`` other than the bound graph) and a small blacklist
+      hides tools that conflict with the panel's scope
+      (``create_agent`` / ``customize_agent`` / ``get_agent_building_guide``
+      — see :data:`BUILDER_BLOCKED_TOOLS`). Read-side lookups
+      (``find_block``, ``find_agent``, ``search_docs``, …) stay open.

     Extra/unknown fields are rejected (422) to prevent silent mis-use.
     """

     model_config = ConfigDict(extra="forbid")

     dry_run: bool = False
+    builder_graph_id: str | None = Field(default=None, max_length=128)

@@ -178,6 +224,11 @@ class ActiveStreamInfo(BaseModel):

     turn_id: str
     last_message_id: str  # Redis Stream message ID for resumption
+    # ISO-8601 timestamp (UTC) marking when the backend registered the turn
+    # as running. Lets the frontend seed its elapsed-time counter so restored
+    # turns show honest "time since turn started" instead of the misleading
+    # "time since this mount resumed the SSE".
+    started_at: str | None = None


 class SessionDetailResponse(BaseModel):
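With `started_at` carried in the stream info, a resuming client can seed its elapsed-time counter from the backend's clock instead of its own mount time. A sketch of that arithmetic in Python terms (helper name is illustrative, only the ISO-8601 field format is taken from the diff):

```python
from datetime import datetime, timezone


def elapsed_seconds(started_at_iso: str, now: datetime) -> float:
    # started_at is an ISO-8601 UTC timestamp set when the backend
    # registered the turn as running, so the counter survives refreshes.
    started = datetime.fromisoformat(started_at_iso)
    return (now - started).total_seconds()


now = datetime(2026, 1, 1, 0, 1, 30, tzinfo=timezone.utc)
secs = elapsed_seconds("2026-01-01T00:00:00+00:00", now)  # 90.0
```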
@@ -269,8 +320,11 @@ async def list_sessions(
     redis = await get_redis_async()
     pipe = redis.pipeline(transaction=False)
     for session in sessions:
+        # Use the canonical helper so the hash-tag braces match every
+        # other writer; building the key inline drops the braces and
+        # silently misses every running session on cluster mode.
         pipe.hget(
-            f"{config.session_meta_prefix}{session.session_id}",
+            stream_registry.get_session_meta_key(session.session_id),
             "status",
         )
     statuses = await pipe.execute()
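The comment added above is about Redis Cluster key routing: when a key contains `{…}`, only the substring inside the braces is hashed for slot assignment, so keys sharing a hash tag are guaranteed to land on the same node and can be read in one pipeline. A hypothetical key builder (not the repo's actual `get_session_meta_key`) showing the shape such a helper returns:

```python
def session_meta_key(session_id: str) -> str:
    # "{chat:<id>}" is the hash tag: Redis Cluster hashes only the part
    # inside the braces, so all keys for one session share a slot.
    return f"{{chat:{session_id}}}:meta"


def session_stream_key(session_id: str) -> str:
    # Same hash tag, different suffix: co-located with the meta key.
    return f"{{chat:{session_id}}}:stream"


meta = session_meta_key("abc123")      # "{chat:abc123}:meta"
stream = session_stream_key("abc123")  # "{chat:abc123}:stream"
```

Building the key with a bare f-string that omits the braces would hash the whole string, scattering a session's keys across slots, which is the bug the comment warns about.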
@@ -306,29 +360,43 @@ async def create_session(
     user_id: Annotated[str, Security(auth.get_user_id)],
     request: CreateSessionRequest | None = None,
 ) -> CreateSessionResponse:
-    """
-    Create a new chat session.
+    """Create (or get-or-create) a chat session.

-    Initiates a new chat session for the authenticated user.
+    Two modes, selected by the request body:

+    - Default: create a fresh session for the user. ``dry_run=True`` forces
+      run_block and run_agent calls to use dry-run simulation.
+    - Builder-bound: when ``builder_graph_id`` is set, get-or-create keyed
+      on ``(user_id, builder_graph_id)``. Returns the existing session for
+      that graph or creates one locked to it. Graph ownership is validated
+      inside :func:`get_or_create_builder_session`; raises 404 on
+      unauthorized access. Write-side scope is enforced per-tool
+      (``edit_agent`` / ``run_agent`` reject any ``agent_id`` other than
+      the bound graph) and a small blacklist hides tools that conflict
+      with the panel's scope (see :data:`BUILDER_BLOCKED_TOOLS`).

     Args:
         user_id: The authenticated user ID parsed from the JWT (required).
-        request: Optional request body. When provided, ``dry_run=True``
-            forces run_block and run_agent calls to use dry-run simulation.
+        request: Optional request body with ``dry_run`` and/or
+            ``builder_graph_id``.

     Returns:
-        CreateSessionResponse: Details of the created session.
+        CreateSessionResponse: Details of the resulting session.

     """
     dry_run = request.dry_run if request else False
+    builder_graph_id = request.builder_graph_id if request else None

     logger.info(
         f"Creating session with user_id: "
         f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
         f"{', dry_run=True' if dry_run else ''}"
+        f"{f', builder_graph_id={builder_graph_id}' if builder_graph_id else ''}"
     )

-    session = await create_chat_session(user_id, dry_run=dry_run)
+    if builder_graph_id:
+        session = await get_or_create_builder_session(user_id, builder_graph_id)
+    else:
+        session = await create_chat_session(user_id, dry_run=dry_run)

     return CreateSessionResponse(
         id=session.session_id,
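A sketch of the get-or-create semantics the new `builder_graph_id` branch relies on, keyed on `(user_id, builder_graph_id)`. This is an in-memory illustration only; the real `get_or_create_builder_session` also validates graph ownership and raises 404 on unauthorized access, which a plain dict cannot show:

```python
# Minimal in-memory sketch of get-or-create keyed on (user_id, graph_id).
# Names and the session-id scheme are assumptions for illustration.
_builder_sessions: dict[tuple[str, str], str] = {}


def get_or_create_builder_session(user_id: str, graph_id: str) -> str:
    key = (user_id, graph_id)
    if key not in _builder_sessions:
        # First request for this (user, graph) pair creates the session;
        # every later request returns the same one.
        _builder_sessions[key] = f"session-{len(_builder_sessions) + 1}"
    return _builder_sessions[key]


first = get_or_create_builder_session("u1", "g1")
again = get_or_create_builder_session("u1", "g1")
other = get_or_create_builder_session("u1", "g2")
assert first == again  # same (user, graph): same session
assert first != other  # different graph: new session
```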
@@ -463,22 +531,13 @@ async def get_session(

     Supports cursor-based pagination via ``limit`` and ``before_sequence``.
     When no pagination params are provided, returns the most recent messages.

-    Args:
-        session_id: The unique identifier for the desired chat session.
-        user_id: The authenticated user's ID.
-        limit: Maximum number of messages to return (1-200, default 50).
-        before_sequence: Return messages with sequence < this value (cursor).
-
-    Returns:
-        SessionDetailResponse: Details for the requested session, including
-            active_stream info and pagination metadata.
     """
     page = await get_chat_messages_paginated(
         session_id, limit, before_sequence, user_id=user_id
     )
     if page is None:
         raise NotFoundError(f"Session {session_id} not found.")

     messages = [
         _strip_injected_context(message.model_dump()) for message in page.messages
     ]
@@ -489,14 +548,11 @@ async def get_session(
     active_session, last_message_id = await stream_registry.get_active_session(
         session_id, user_id
     )
-    logger.info(
-        f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
-        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
-    )
     if active_session:
         active_stream_info = ActiveStreamInfo(
             turn_id=active_session.turn_id,
             last_message_id=last_message_id,
+            started_at=active_session.created_at.isoformat(),
         )

     # Skip session metadata on "load more" — frontend only needs messages
@@ -537,23 +593,27 @@ async def get_session(
 )
 async def get_copilot_usage(
     user_id: Annotated[str, Security(auth.get_user_id)],
-) -> CoPilotUsageStatus:
+) -> CoPilotUsagePublic:
     """Get CoPilot usage status for the authenticated user.

-    Returns current token usage vs limits for daily and weekly windows.
-    Global defaults sourced from LaunchDarkly (falling back to config).
-    Includes the user's rate-limit tier.
+    Returns the percentage of the daily/weekly allowance used — not the
+    raw spend or cap — so clients cannot derive per-turn cost or platform
+    margins. Global defaults sourced from LaunchDarkly (falling back to
+    config). Includes the user's rate-limit tier.
     """
     daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        user_id, config.daily_token_limit, config.weekly_token_limit
+        user_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
     )
-    return await get_usage_status(
+    status = await get_usage_status(
         user_id=user_id,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
+        daily_cost_limit=daily_limit,
+        weekly_cost_limit=weekly_limit,
         rate_limit_reset_cost=config.rate_limit_reset_cost,
         tier=tier,
     )
+    return CoPilotUsagePublic.from_status(status)


 class RateLimitResetResponse(BaseModel):
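The privacy argument in the new docstring (return percentages, never raw spend or caps) can be sketched as below. The field and class names are assumptions; only the used/limit to percent conversion is the point:

```python
from dataclasses import dataclass


@dataclass
class WindowUsage:
    used: int   # microdollars spent in the window (internal only)
    limit: int  # microdollar cap for the window (internal only)


@dataclass
class CoPilotUsagePublicSketch:
    """Hypothetical public shape: percentages only, no spend or cap."""

    daily_percent_used: float
    weekly_percent_used: float

    @classmethod
    def from_status(
        cls, daily: WindowUsage, weekly: WindowUsage
    ) -> "CoPilotUsagePublicSketch":
        def pct(w: WindowUsage) -> float:
            if w.limit <= 0:  # limit disabled: report 0% rather than divide
                return 0.0
            return min(100.0, round(100.0 * w.used / w.limit, 1))

        return cls(pct(daily), pct(weekly))


status = CoPilotUsagePublicSketch.from_status(
    WindowUsage(used=250_000, limit=1_000_000),
    WindowUsage(used=250_000, limit=5_000_000),
)
assert status.daily_percent_used == 25.0
assert status.weekly_percent_used == 5.0
```

Because only the ratio crosses the API boundary, a client cannot invert it to recover the microdollar spend or the configured cap.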
@@ -562,7 +622,9 @@ class RateLimitResetResponse(BaseModel):
     success: bool
     credits_charged: int = Field(description="Credits charged (in cents)")
     remaining_balance: int = Field(description="Credit balance after charge (in cents)")
-    usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
+    usage: CoPilotUsagePublic = Field(
+        description="Updated usage status after reset (percentages only)"
+    )


 @router.post(
@@ -586,7 +648,7 @@ async def reset_copilot_usage(
 ) -> RateLimitResetResponse:
     """Reset the daily CoPilot rate limit by spending credits.

-    Allows users who have hit their daily token limit to spend credits
+    Allows users who have hit their daily cost limit to spend credits
     to reset their daily usage counter and continue working.
     Returns 400 if the feature is disabled or the user is not over the limit.
     Returns 402 if the user has insufficient credits.
@@ -605,7 +667,9 @@ async def reset_copilot_usage(
     )

     daily_limit, weekly_limit, tier = await get_global_rate_limits(
-        user_id, config.daily_token_limit, config.weekly_token_limit
+        user_id,
+        config.daily_cost_limit_microdollars,
+        config.weekly_cost_limit_microdollars,
     )

     if daily_limit <= 0:
@@ -642,8 +706,8 @@ async def reset_copilot_usage(
     # used for limit checks, not returned to the client.)
     usage_status = await get_usage_status(
         user_id=user_id,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
+        daily_cost_limit=daily_limit,
+        weekly_cost_limit=weekly_limit,
         tier=tier,
     )
     if daily_limit > 0 and usage_status.daily.used < daily_limit:
@@ -678,7 +742,7 @@ async def reset_copilot_usage(

     # Reset daily usage in Redis. If this fails, refund the credits
     # so the user is not charged for a service they did not receive.
-    if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
+    if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit):
         # Compensate: refund the charged credits.
         refunded = False
         try:
@@ -714,11 +778,11 @@ async def reset_copilot_usage(
     finally:
         await release_reset_lock(user_id)

-    # Return updated usage status.
+    # Return updated usage status (public schema — percentages only).
     updated_usage = await get_usage_status(
         user_id=user_id,
-        daily_token_limit=daily_limit,
-        weekly_token_limit=weekly_limit,
+        daily_cost_limit=daily_limit,
+        weekly_cost_limit=weekly_limit,
         rate_limit_reset_cost=config.rate_limit_reset_cost,
         tier=tier,
     )
@@ -727,7 +791,7 @@ async def reset_copilot_usage(
         success=True,
         credits_charged=cost,
         remaining_balance=remaining,
-        usage=updated_usage,
+        usage=CoPilotUsagePublic.from_status(updated_usage),
     )


@@ -776,38 +840,81 @@ async def cancel_session_task(
     return CancelSessionResponse(cancelled=True)


+def _ui_message_stream_headers() -> dict[str, str]:
+    return {
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        "X-Accel-Buffering": "no",
+        "x-vercel-ai-ui-message-stream": "v1",
+    }
+
+
+def _empty_ui_message_stream_response() -> StreamingResponse:
+    # Stable placeholder messageId for the empty queued-mid-turn stream.
+    # Real turns generate per-message UUIDs via the executor; this stream
+    # has no message to attach to, but the AI SDK parser still requires a
+    # non-empty ``messageId`` field on ``StreamStart``.
+    message_id = uuid4().hex
+
+    async def event_generator() -> AsyncGenerator[str, None]:
+        # Vercel AI SDK's UI-message-stream parser expects symmetric
+        # start/finish framing at both stream and step level — every
+        # non-empty turn emits the pair. Without an opener, today's parser
+        # tolerates the closer (no active parts to flush) but a future SDK
+        # tightening would silently break the queue-mid-turn UX. Emit the
+        # full empty pair so the contract stays correct.
+        yield StreamStart(messageId=message_id).to_sse()
+        yield StreamStartStep().to_sse()
+        yield StreamFinishStep().to_sse()
+        yield StreamFinish().to_sse()
+        yield "data: [DONE]\n\n"
+
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+        headers=_ui_message_stream_headers(),
+    )
+
+
 @router.post(
     "/sessions/{session_id}/stream",
+    responses={
+        404: {"description": "Session not found or access denied"},
+        429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
+    },
 )
 async def stream_chat_post(
     session_id: str,
     request: StreamChatRequest,
     user_id: str = Security(auth.get_user_id),
 ):
-    """
-    Stream chat responses for a session (POST with context support).
+    """Start a new turn and return an AI SDK UI message stream.

-    Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
-    - Text fragments as they are generated
-    - Tool call UI elements (if invoked)
-    - Tool execution results
+    Returns an SSE stream (``text/event-stream``) with Vercel AI SDK chunks
+    (text fragments, tool-call UI, tool results). The generation runs in a
+    background task that survives client disconnects; reconnect via
+    ``GET /sessions/{session_id}/stream`` to resume.

-    The AI generation runs in a background task that continues even if the client disconnects.
-    All chunks are written to a per-turn Redis stream for reconnection support. If the client
-    disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
+    Follow-up messages typed while a turn is already running should use
+    ``POST /sessions/{session_id}/messages/pending``. If an older client still
+    posts that follow-up here, we queue it defensively but still return a valid
+    empty UI-message stream so AI SDK transports never receive a JSON body from
+    the stream endpoint.

     Args:
-        session_id: The chat session identifier to associate with the streamed messages.
-        request: Request body containing message, is_user_message, and optional context.
+        session_id: The chat session identifier.
+        request: Request body with message, is_user_message, and optional context.
         user_id: Authenticated user ID.
-
-    Returns:
-        StreamingResponse: SSE-formatted response chunks.
     """
     import asyncio
     import time

     stream_start_time = time.perf_counter()
+    # Wall-clock arrival time, propagated to the executor so the turn-start
+    # drain can order pending messages relative to this request (pending
+    # pushed BEFORE this instant were typed earlier; pending pushed AFTER
+    # are race-path follow-ups typed while /stream was still processing).
+    request_arrival_at = time.time()
     log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}

     logger.info(
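The symmetric framing that `_empty_ui_message_stream_response` emits can be made concrete. The chunk `type` strings below follow the AI SDK UI-message-stream wire format as commonly documented and are assumptions here; the shape of the five SSE frames is what matters:

```python
import json
from uuid import uuid4

# Minimal stand-in for the Stream* chunk classes: each serialises to one
# SSE "data:" line terminated by a blank line.


def to_sse(payload: dict) -> str:
    return f"data: {json.dumps(payload)}\n\n"


message_id = uuid4().hex
frames = [
    # Stream-level opener (requires a non-empty messageId), then the
    # step-level opener/closer pair, then the stream-level closer.
    to_sse({"type": "start", "messageId": message_id}),
    to_sse({"type": "start-step"}),
    to_sse({"type": "finish-step"}),
    to_sse({"type": "finish"}),
    "data: [DONE]\n\n",
]

assert frames[0].startswith('data: {"type": "start"')
assert frames[-1] == "data: [DONE]\n\n"
assert all(f.endswith("\n\n") for f in frames)
```

Even with zero content parts between them, the opener/closer pairs keep the stream a valid instance of the protocol rather than a special case the parser must tolerate.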
@@ -815,7 +922,31 @@ async def stream_chat_post(
         f"user={user_id}, message_len={len(request.message)}",
         extra={"json_fields": log_meta},
     )
-    await _validate_and_get_session(session_id, user_id)
+    session = await _validate_and_get_session(session_id, user_id)
+
+    if (
+        request.is_user_message
+        and request.message
+        and await is_turn_in_flight(session_id)
+    ):
+        try:
+            await queue_pending_for_http(
+                session_id=session_id,
+                user_id=user_id,
+                message=request.message,
+                context=request.context,
+                file_ids=request.file_ids,
+            )
+            return _empty_ui_message_stream_response()
+        except HTTPException as exc:
+            if exc.status_code != 409:
+                raise
+
+    # Permission resolution is only needed below for the actual turn — keep
+    # it after the queue-fall-through so a queued mid-turn request returns
+    # without paying the work.
+    builder_permissions = resolve_session_permissions(session)

     logger.info(
         f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
         extra={
@@ -826,18 +957,20 @@ async def stream_chat_post(
         },
     )

-    # Pre-turn rate limit check (token-based).
+    # Pre-turn rate limit check (cost-based, microdollars).
     # check_rate_limit short-circuits internally when both limits are 0.
     # Global defaults sourced from LaunchDarkly, falling back to config.
     if user_id:
         try:
             daily_limit, weekly_limit, _ = await get_global_rate_limits(
-                user_id, config.daily_token_limit, config.weekly_token_limit
+                user_id,
+                config.daily_cost_limit_microdollars,
+                config.weekly_cost_limit_microdollars,
             )
             await check_rate_limit(
                 user_id=user_id,
-                daily_token_limit=daily_limit,
-                weekly_token_limit=weekly_limit,
+                daily_cost_limit=daily_limit,
+                weekly_cost_limit=weekly_limit,
             )
         except RateLimitExceeded as e:
             raise HTTPException(status_code=429, detail=str(e)) from e
@@ -846,89 +979,41 @@ async def stream_chat_post(
     # Also sanitise file_ids so only validated, workspace-scoped IDs are
     # forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
     sanitized_file_ids: list[str] | None = None
-    # Capture the original message text BEFORE any mutation (attachment enrichment)
-    # so the idempotency hash is stable across retries.
-    original_message = request.message
-    if request.file_ids and user_id:
-        # Filter to valid UUIDs only to prevent DB abuse
-        valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
-
-        if valid_ids:
-            workspace = await get_or_create_workspace(user_id)
-            # Batch query instead of N+1
-            files = await UserWorkspaceFile.prisma().find_many(
-                where={
-                    "id": {"in": valid_ids},
-                    "workspaceId": workspace.id,
-                    "isDeleted": False,
-                }
-            )
-            # Only keep IDs that actually exist in the user's workspace
-            sanitized_file_ids = [wf.id for wf in files] or None
-            file_lines: list[str] = [
-                f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
-                for wf in files
-            ]
-            if file_lines:
-                files_block = (
-                    "\n\n[Attached files]\n"
-                    + "\n".join(file_lines)
-                    + "\nUse read_workspace_file with the file_id to access file contents."
-                )
-                request.message += files_block
-
-    # ── Idempotency guard ────────────────────────────────────────────────────
-    # Blocks duplicate executor tasks from concurrent/retried POSTs.
-    # See backend/copilot/message_dedup.py for the full lifecycle description.
-    dedup_lock = None
-    if request.is_user_message:
-        dedup_lock = await acquire_dedup_lock(
-            session_id, original_message, sanitized_file_ids
-        )
-        if dedup_lock is None and (original_message or sanitized_file_ids):
-
-            async def _empty_sse() -> AsyncGenerator[str, None]:
-                yield StreamFinish().to_sse()
-                yield "data: [DONE]\n\n"
-
-            return StreamingResponse(
-                _empty_sse(),
-                media_type="text/event-stream",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "X-Accel-Buffering": "no",
-                    "Connection": "keep-alive",
-                    "x-vercel-ai-ui-message-stream": "v1",
-                },
-            )
+    if request.file_ids:
+        files = await resolve_workspace_files(user_id, request.file_ids)
+        sanitized_file_ids = [wf.id for wf in files] or None
+        request.message += build_files_block(files)

     # Atomically append user message to session BEFORE creating task to avoid
     # race condition where GET_SESSION sees task as "running" but message isn't
-    # saved yet. append_and_save_message re-fetches inside a lock to prevent
-    # message loss from concurrent requests.
-    #
-    # If any of these operations raises, release the dedup lock before propagating
-    # so subsequent retries are not blocked for 30 s.
-    try:
-        if request.message:
-            message = ChatMessage(
-                role="user" if request.is_user_message else "assistant",
-                content=request.message,
-            )
-            if request.is_user_message:
-                track_user_message(
-                    user_id=user_id,
-                    session_id=session_id,
-                    message_length=len(request.message),
-                )
-            logger.info(f"[STREAM] Saving user message to session {session_id}")
+    # saved yet. append_and_save_message returns None when a duplicate is
+    # detected — in that case skip enqueue to avoid processing the message twice.
+    is_duplicate_message = False
+    if request.message:
+        message = ChatMessage(
+            role="user" if request.is_user_message else "assistant",
+            content=request.message,
+        )
+        logger.info(f"[STREAM] Saving user message to session {session_id}")
+        is_duplicate_message = (
             await append_and_save_message(session_id, message)
-            logger.info(f"[STREAM] User message saved for session {session_id}")
+        ) is None
+        logger.info(f"[STREAM] User message saved for session {session_id}")
+    if not is_duplicate_message and request.is_user_message:
+        track_user_message(
+            user_id=user_id,
+            session_id=session_id,
+            message_length=len(request.message),
+        )

-        # Create a task in the stream registry for reconnection support
+    # Create a task in the stream registry for reconnection support.
+    # For duplicate messages, skip create_session entirely so the infra-retry
+    # client subscribes to the *existing* turn's Redis stream and receives the
+    # in-progress executor output rather than an empty stream.
+    turn_id = ""
+    if not is_duplicate_message:
         turn_id = str(uuid4())
         log_meta["turn_id"] = turn_id

         session_create_start = time.perf_counter()
         await stream_registry.create_session(
             session_id=session_id,
@@ -946,7 +1031,6 @@ async def stream_chat_post(
                 }
             },
         )
-
         await enqueue_copilot_turn(
             session_id=session_id,
             user_id=user_id,
@@ -957,11 +1041,13 @@ async def stream_chat_post(
             file_ids=sanitized_file_ids,
             mode=request.mode,
             model=request.model,
+            permissions=builder_permissions,
+            request_arrival_at=request_arrival_at,
+        )
+    else:
+        logger.info(
+            f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
         )
-    except Exception:
-        if dedup_lock:
-            await dedup_lock.release()
-        raise

     setup_time = (time.perf_counter() - stream_start_time) * 1000
     logger.info(
@@ -985,12 +1071,6 @@ async def stream_chat_post(
     subscriber_queue = None
     first_chunk_yielded = False
     chunks_yielded = 0
-    # True for every exit path except GeneratorExit (client disconnect).
-    # On disconnect the backend turn is still running — releasing the lock
-    # there would reopen the infra-retry duplicate window. The 30 s TTL
-    # is the fallback. All other exits (normal finish, early return, error)
-    # should release so the user can re-send the same message.
-    release_dedup_lock_on_exit = True
     try:
         # Subscribe from the position we captured before enqueuing
         # This avoids replaying old messages while catching all new ones
@@ -1002,7 +1082,7 @@ async def stream_chat_post(

         if subscriber_queue is None:
             yield StreamFinish().to_sse()
-            return  # finally releases dedup_lock
+            return

         # Read from the subscriber queue and yield to SSE
         logger.info(
@@ -1044,7 +1124,7 @@ async def stream_chat_post(
                     }
                 },
             )
-            break  # finally releases dedup_lock
+            break

         except asyncio.TimeoutError:
             yield StreamHeartbeat().to_sse()
@@ -1060,7 +1140,6 @@ async def stream_chat_post(
                     }
                 },
             )
-            release_dedup_lock_on_exit = False
         except Exception as e:
             elapsed = (time_module.perf_counter() - event_gen_start) * 1000
             logger.error(
@@ -1075,10 +1154,7 @@ async def stream_chat_post(
                 code="stream_error",
             ).to_sse()
             yield StreamFinish().to_sse()
-            # finally releases dedup_lock
         finally:
-            if dedup_lock and release_dedup_lock_on_exit:
-                await dedup_lock.release()
             # Unsubscribe when client disconnects or stream ends
             if subscriber_queue is not None:
                 try:
@@ -1108,12 +1184,62 @@ async def stream_chat_post(
     return StreamingResponse(
         event_generator(),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=_ui_message_stream_headers(),
+    )
+
+
+@router.post(
+    "/sessions/{session_id}/messages/pending",
+    response_model=QueuePendingMessageResponse,
+    responses={
+        404: {"description": "Session not found or access denied"},
+        409: {"description": "Session has no active turn to receive pending messages"},
+        429: {"description": "Call-frequency cap exceeded"},
+    },
+)
+async def queue_pending_message(
+    session_id: str,
+    request: QueuePendingMessageRequest,
+    user_id: str = Security(auth.get_user_id),
+):
+    """Queue a follow-up message while the session has an active turn."""
+    await _validate_and_get_session(session_id, user_id)
+    if not await is_turn_in_flight(session_id):
+        raise HTTPException(
+            status_code=409,
+            detail="Session has no active turn. Start a new turn with POST /stream.",
+        )
+    return await queue_pending_for_http(
+        session_id=session_id,
+        user_id=user_id,
+        message=request.message,
+        context=request.context,
+        file_ids=request.file_ids,
+    )
+
+
+@router.get(
+    "/sessions/{session_id}/messages/pending",
+    response_model=PeekPendingMessagesResponse,
+    responses={
+        404: {"description": "Session not found or access denied"},
+    },
+)
+async def get_pending_messages(
+    session_id: str,
+    user_id: str = Security(auth.get_user_id),
+):
+    """Peek at the pending-message buffer without consuming it.
+
+    Returns the current contents of the session's pending message buffer
+    so the frontend can restore the queued-message indicator after a page
+    refresh and clear it correctly once a turn drains the buffer.
+    """
+    await _validate_and_get_session(session_id, user_id)
+    pending = await peek_pending_messages(session_id)
+    return PeekPendingMessagesResponse(
+        messages=[m.content for m in pending],
+        count=len(pending),
     )


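Client-side, the new pending-message endpoint is what a frontend should hit instead of `/stream` while a turn is already running. A dependency-free sketch of building that request follows; the base URL, port, and bearer-token auth scheme are assumptions:

```python
import json
import urllib.request

BASE = "http://localhost:8006/api/chat"  # assumed mount point


def queue_follow_up(
    session_id: str, token: str, message: str
) -> urllib.request.Request:
    # POST …/messages/pending is only valid while a turn is in flight;
    # the server answers 409 otherwise, and the client should fall back
    # to starting a new turn via POST …/stream.
    return urllib.request.Request(
        f"{BASE}/sessions/{session_id}/messages/pending",
        data=json.dumps({"message": message}).encode(),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


req = queue_follow_up("sess-1", "tok", "also rename the agent")
assert req.get_method() == "POST"
assert req.full_url.endswith("/sessions/sess-1/messages/pending")
```

A matching `GET` on the same path (per the route above) lets the client re-render the queued-message indicator after a page refresh.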
@@ -1122,6 +1248,7 @@ async def stream_chat_post(
 )
 async def resume_session_stream(
     session_id: str,
+    last_chunk_id: str | None = Query(default=None, include_in_schema=False),
     user_id: str = Security(auth.get_user_id),
 ):
     """
@@ -1131,27 +1258,26 @@ async def resume_session_stream(
     Checks for an active (in-progress) task on the session and either replays
     the full SSE stream or returns 204 No Content if nothing is running.
 
-    Args:
-        session_id: The chat session identifier.
-        user_id: Optional authenticated user ID.
-    Returns:
-        StreamingResponse (SSE) when an active stream exists,
-        or 204 No Content when there is nothing to resume.
+    Always replays the active turn from ``0-0``. The AI SDK UI-message parser
+    keeps text/reasoning part state inside a single parser instance; resuming
+    from a Redis cursor can skip the ``*-start`` events required by later
+    ``*-delta`` chunks.
     """
     import asyncio
 
-    active_session, last_message_id = await stream_registry.get_active_session(
+    active_session, _latest_backend_id = await stream_registry.get_active_session(
         session_id, user_id
     )
 
     if not active_session:
         return Response(status_code=204)
 
-    # Always replay from the beginning ("0-0") on resume.
-    # We can't use last_message_id because it's the latest ID in the backend
-    # stream, not the latest the frontend received — the gap causes lost
-    # messages. The frontend deduplicates replayed content.
+    if last_chunk_id:
+        logger.info(
+            "Ignoring deprecated last_chunk_id on stream resume",
+            extra={"session_id": session_id, "last_chunk_id": last_chunk_id},
+        )
 
     subscriber_queue = await stream_registry.subscribe_to_session(
         session_id=session_id,
         user_id=user_id,
@@ -1212,12 +1338,7 @@ async def resume_session_stream(
     return StreamingResponse(
         event_generator(),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",
-            "x-vercel-ai-ui-message-stream": "v1",
-        },
+        headers=_ui_message_stream_headers(),
     )
 
 
@@ -1373,6 +1494,7 @@ ToolResponseUnion = (
     | MemorySearchResponse
     | MemoryForgetCandidatesResponse
    | MemoryForgetConfirmResponse
+    | TodoWriteResponse
 )
 
 
File diff suppressed because it is too large
@@ -7,6 +7,7 @@ allowing frontend code generators like Orval to create corresponding TypeScript
 
 from pydantic import BaseModel, Field
 
+from backend.data.model import CredentialsType
 from backend.integrations.providers import ProviderName
 from backend.sdk.registry import AutoRegistry
 
@@ -47,6 +48,57 @@ class ProviderNamesResponse(BaseModel):
     )
 
 
+class ProviderMetadata(BaseModel):
+    """Display metadata for a provider, shown in the settings integrations UI."""
+
+    name: str = Field(description="Provider slug (e.g. ``github``)")
+    description: str | None = Field(
+        default=None,
+        description=(
+            "One-line human-readable summary of what the provider does. "
+            "Declared via ``ProviderBuilder.with_description(...)`` in the "
+            "provider's ``_config.py``. ``None`` if not set."
+        ),
+    )
+    supported_auth_types: list[CredentialsType] = Field(
+        default_factory=list,
+        description=(
+            "Credential types this provider accepts. Drives which connection "
+            "tabs the settings UI renders for the provider. Empty list means "
+            "no auth types declared."
+        ),
+    )
+
+
+def get_supported_auth_types(name: str) -> list[CredentialsType]:
+    """Return the provider's supported credential types from :class:`AutoRegistry`.
+
+    Populated by :meth:`ProviderBuilder.with_supported_auth_types` (or by
+    ``with_oauth`` / ``with_api_key`` / ``with_user_password`` when the provider
+    uses the full builder chain). Returns an empty list for providers with no
+    auth types declared.
+    """
+    provider = AutoRegistry.get_provider(name)
+    if provider is None:
+        return []
+    return sorted(provider.supported_auth_types)
+
+
+def get_provider_description(name: str) -> str | None:
+    """Return the provider's description from :class:`AutoRegistry`.
+
+    Descriptions are declared via ``ProviderBuilder.with_description(...)`` in
+    the provider's ``_config.py`` (SDK path) or in
+    ``blocks/_static_provider_configs.py`` (for providers that don't yet have
+    their own directory). Returns ``None`` for providers with no registered
+    description.
+    """
+    provider = AutoRegistry.get_provider(name)
+    if provider is None:
+        return None
+    return provider.description
+
+
 class ProviderConstants(BaseModel):
     """
     Model that exposes all provider names as a constant in the OpenAPI schema.
@@ -14,7 +14,7 @@ from fastapi import (
     Security,
     status,
 )
-from pydantic import BaseModel, Field, SecretStr, model_validator
+from pydantic import BaseModel, Field, model_validator
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
 
 from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -29,15 +29,14 @@ from backend.data.integrations import (
     wait_for_webhook_event,
 )
 from backend.data.model import (
+    APIKeyCredentials,
     Credentials,
     CredentialsType,
     HostScopedCredentials,
     OAuth2Credentials,
-    UserIntegrations,
     is_sdk_default,
 )
 from backend.data.onboarding import OnboardingStep, complete_onboarding_step
-from backend.data.user import get_user_integrations
 from backend.executor.utils import add_graph_execution
 from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
 from backend.integrations.credentials_store import (
@@ -48,7 +47,14 @@ from backend.integrations.creds_manager import (
     IntegrationCredentialsManager,
     create_mcp_oauth_handler,
 )
-from backend.integrations.managed_credentials import ensure_managed_credentials
+from backend.integrations.managed_credentials import (
+    ensure_managed_credential,
+    ensure_managed_credentials,
+)
+from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvider
+from backend.integrations.managed_providers.ayrshare import (
+    settings_available as ayrshare_settings_available,
+)
 from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
 from backend.integrations.providers import ProviderName
 from backend.integrations.webhooks import get_webhook_manager
@@ -60,7 +66,14 @@ from backend.util.exceptions import (
 )
 from backend.util.settings import Settings
 
-from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
+from .models import (
+    ProviderConstants,
+    ProviderMetadata,
+    ProviderNamesResponse,
+    get_all_provider_names,
+    get_provider_description,
+    get_supported_auth_types,
+)
 
 if TYPE_CHECKING:
     from backend.integrations.oauth import BaseOAuthHandler
@@ -87,14 +100,23 @@ async def login(
     scopes: Annotated[
         str, Query(title="Comma-separated list of authorization scopes")
     ] = "",
+    credential_id: Annotated[
+        str | None,
+        Query(title="ID of existing credential to upgrade scopes for"),
+    ] = None,
 ) -> LoginResponse:
     handler = _get_provider_oauth_handler(request, provider)
 
     requested_scopes = scopes.split(",") if scopes else []
 
+    if credential_id:
+        requested_scopes = await _prepare_scope_upgrade(
+            user_id, provider, credential_id, requested_scopes
+        )
+
     # Generate and store a secure random state token along with the scopes
     state_token, code_challenge = await creds_manager.store.store_state_token(
-        user_id, provider, requested_scopes
+        user_id, provider, requested_scopes, credential_id=credential_id
     )
     login_url = handler.get_login_url(
         requested_scopes, state_token, code_challenge=code_challenge
@@ -216,7 +238,9 @@ async def callback(
     )
 
     # TODO: Allow specifying `title` to set on `credentials`
-    await creds_manager.create(user_id, credentials)
+    credentials = await _merge_or_create_credential(
+        user_id, provider, credentials, valid_state.credential_id
+    )
 
     logger.debug(
         f"Successfully processed OAuth callback for user {user_id} "
@@ -226,13 +250,38 @@ async def callback(
     return to_meta_response(credentials)
 
 
+# Bound the first-time sweep so a slow upstream (e.g. Ayrshare) can't hang
+# the credential-list endpoint. On timeout we still kick off a fire-and-
+# forget sweep so provisioning eventually completes; the user just won't
+# see the managed cred until the next refresh.
+_MANAGED_PROVISION_TIMEOUT_S = 10.0
+
+
+async def _ensure_managed_credentials_bounded(user_id: str) -> None:
+    try:
+        await asyncio.wait_for(
+            ensure_managed_credentials(user_id, creds_manager.store),
+            timeout=_MANAGED_PROVISION_TIMEOUT_S,
+        )
+    except asyncio.TimeoutError:
+        logger.warning(
+            "Managed credential sweep exceeded %.1fs for user=%s; "
+            "continuing without it — provisioning will complete in background",
+            _MANAGED_PROVISION_TIMEOUT_S,
+            user_id,
+        )
+        asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
+
+
 @router.get("/credentials", summary="List Credentials")
 async def list_credentials(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> list[CredentialsMetaResponse]:
-    # Fire-and-forget: provision missing managed credentials in the background.
-    # The credential appears on the next page load; listing is never blocked.
-    asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
+    # Block on provisioning so managed credentials appear on the first load
+    # instead of after a refresh, but with a timeout so a slow upstream
+    # can't hang the endpoint. `_provisioned_users` short-circuits on
+    # repeat calls.
+    await _ensure_managed_credentials_bounded(user_id)
     credentials = await creds_manager.store.get_all_creds(user_id)
 
     return [
@@ -247,7 +296,7 @@ async def list_credentials_by_provider(
     ],
     user_id: Annotated[str, Security(get_user_id)],
 ) -> list[CredentialsMetaResponse]:
-    asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
+    await _ensure_managed_credentials_bounded(user_id)
     credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
 
     return [
@@ -281,6 +330,115 @@ async def get_credential(
     return to_meta_response(credential)
 
 
+class PickerTokenResponse(BaseModel):
+    """Short-lived OAuth access token shipped to the browser for rendering a
+    provider-hosted picker UI (e.g. Google Drive Picker). Deliberately narrow:
+    only the fields the client needs to initialize the picker widget. Issued
+    from the user's own stored credential so ownership and scope gating are
+    enforced by the credential lookup."""
+
+    access_token: str = Field(
+        description="OAuth access token suitable for the picker SDK call."
+    )
+    access_token_expires_at: int | None = Field(
+        default=None,
+        description="Unix timestamp at which the access token expires, if known.",
+    )
+
+
+# Allowlist of (provider, scopes) tuples that may mint picker tokens. Only
+# Drive-picker-capable scopes qualify so a caller can't use this endpoint to
+# extract a GitHub / other-provider OAuth token for unrelated purposes. If a
+# future provider integrates a hosted picker that needs a raw access token,
+# add its specific picker-relevant scopes here.
+_PICKER_TOKEN_ALLOWED_SCOPES: dict[ProviderName, frozenset[str]] = {
+    ProviderName.GOOGLE: frozenset(
+        [
+            "https://www.googleapis.com/auth/drive.file",
+            "https://www.googleapis.com/auth/drive.readonly",
+            "https://www.googleapis.com/auth/drive",
+        ]
+    ),
+}
+
+
+@router.post(
+    "/{provider}/credentials/{cred_id}/picker-token",
+    summary="Issue a short-lived access token for a provider-hosted picker",
+    operation_id="postV1GetPickerToken",
+)
+async def get_picker_token(
+    provider: Annotated[
+        ProviderName, Path(title="The provider that owns the credentials")
+    ],
+    cred_id: Annotated[
+        str, Path(title="The ID of the OAuth2 credentials to mint a token from")
+    ],
+    user_id: Annotated[str, Security(get_user_id)],
+) -> PickerTokenResponse:
+    """Return the raw access token for an OAuth2 credential so the frontend
+    can initialize a provider-hosted picker (e.g. Google Drive Picker).
+
+    `GET /{provider}/credentials/{cred_id}` deliberately strips secrets (see
+    `CredentialsMetaResponse` + `TestGetCredentialReturnsMetaOnly` in
+    `router_test.py`). That hardening broke the Drive picker, which needs the
+    raw access token to call `google.picker.Builder.setOAuthToken(...)`. This
+    endpoint carves a narrow, explicit hole: the caller must own the
+    credential, it must be OAuth2, and the endpoint returns only the access
+    token + its expiry — nothing else about the credential. SDK-default
+    credentials are excluded for the same reason as `get_credential`.
+    """
+    if is_sdk_default(cred_id):
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
+        )
+
+    credential = await creds_manager.get(user_id, cred_id)
+    if not credential:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
+        )
+    if not provider_matches(credential.provider, provider):
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
+        )
+    if not isinstance(credential, OAuth2Credentials):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Picker tokens are only available for OAuth2 credentials",
+        )
+    if not credential.access_token:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Credential has no access token; reconnect the account",
+        )
+
+    # Gate on provider+scope: only credentials that actually grant access to
+    # a provider-hosted picker flow may mint a token through this endpoint.
+    # Prevents using this path to extract bearer tokens for unrelated OAuth
+    # integrations (e.g. GitHub) that happen to be stored under the same user.
+    allowed_scopes = _PICKER_TOKEN_ALLOWED_SCOPES.get(provider)
+    if not allowed_scopes:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=(f"Picker tokens are not available for provider '{provider.value}'"),
+        )
+    cred_scopes = set(credential.scopes or [])
+    if cred_scopes.isdisjoint(allowed_scopes):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=(
+                "Credential does not grant any scope eligible for the picker. "
+                "Reconnect with the appropriate scope."
+            ),
+        )
+
+    return PickerTokenResponse(
+        access_token=credential.access_token.get_secret_value(),
+        access_token_expires_at=credential.access_token_expires_at,
+    )
+
+
 @router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
 async def create_credentials(
     user_id: Annotated[str, Security(get_user_id)],
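The scope gate in `get_picker_token` above reduces to a set-disjointness check: the credential qualifies when it grants at least one allowlisted scope. A sketch of just that decision, with `ALLOWED` mirroring the Google Drive entries of `_PICKER_TOKEN_ALLOWED_SCOPES` (plain strings stand in for `ProviderName` members):

```python
# Illustrative copy of the picker-token allowlist shape; the real mapping
# is keyed by ProviderName enum members in the router module.
ALLOWED: dict[str, frozenset[str]] = {
    "google": frozenset(
        {
            "https://www.googleapis.com/auth/drive.file",
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive",
        }
    ),
}


def may_mint_picker_token(provider: str, cred_scopes: set[str]) -> bool:
    allowed = ALLOWED.get(provider)
    if not allowed:
        return False  # provider has no picker-capable scopes at all
    # At least one granted scope must overlap the allowlist.
    return not cred_scopes.isdisjoint(allowed)
```

Using `isdisjoint` rather than subset keeps the gate permissive for credentials that carry extra, unrelated scopes alongside a qualifying one.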
@@ -574,6 +732,186 @@ async def _execute_webhook_preset_trigger(
     # Continue processing - webhook should be resilient to individual failures
 
 
+# -------------------- INCREMENTAL AUTH HELPERS -------------------- #
+
+
+async def _prepare_scope_upgrade(
+    user_id: str,
+    provider: ProviderName,
+    credential_id: str,
+    requested_scopes: list[str],
+) -> list[str]:
+    """Validate an existing credential for scope upgrade and compute scopes.
+
+    For providers without native incremental auth (e.g. GitHub), returns the
+    union of existing + requested scopes. For providers that handle merging
+    server-side (e.g. Google with ``include_granted_scopes``), returns the
+    requested scopes unchanged.
+
+    Raises HTTPException on validation failure.
+    """
+    # Platform-owned system credentials must never be upgraded — scope
+    # changes here would leak across every user that shares them.
+    if is_system_credential(credential_id):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="System credentials cannot be upgraded",
+        )
+
+    existing = await creds_manager.store.get_creds_by_id(user_id, credential_id)
+    if not existing:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Credential to upgrade not found",
+        )
+    if not isinstance(existing, OAuth2Credentials):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Only OAuth2 credentials can be upgraded",
+        )
+    if not provider_matches(existing.provider, provider.value):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Credential provider does not match the requested provider",
+        )
+    if existing.is_managed:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Managed credentials cannot be upgraded",
+        )
+
+    # Google handles scope merging via include_granted_scopes; others need
+    # the union of existing + new scopes in the login URL.
+    if provider != ProviderName.GOOGLE:
+        requested_scopes = list(set(requested_scopes) | set(existing.scopes))
+
+    return requested_scopes
+
+
+async def _merge_or_create_credential(
+    user_id: str,
+    provider: ProviderName,
+    credentials: OAuth2Credentials,
+    credential_id: str | None,
+) -> OAuth2Credentials:
+    """Either upgrade an existing credential or create a new one.
+
+    When *credential_id* is set (explicit upgrade), merges scopes and updates
+    the existing credential. Otherwise, checks for an implicit merge (same
+    provider + username) before falling back to creating a new credential.
+    """
+    if credential_id:
+        return await _upgrade_existing_credential(user_id, credential_id, credentials)
+
+    # Implicit merge: check for existing credential with same provider+username.
+    # Skip managed/system credentials and require a non-None username on both
+    # sides so we never accidentally merge unrelated credentials.
+    if credentials.username is None:
+        await creds_manager.create(user_id, credentials)
+        return credentials
+
+    existing_creds = await creds_manager.store.get_creds_by_provider(user_id, provider)
+    matching = next(
+        (
+            c
+            for c in existing_creds
+            if isinstance(c, OAuth2Credentials)
+            and not c.is_managed
+            and not is_system_credential(c.id)
+            and c.username is not None
+            and c.username == credentials.username
+        ),
+        None,
+    )
+    if matching:
+        # Only merge into the existing credential when the new token
+        # already covers every scope we're about to advertise on it.
+        # Without this guard we'd overwrite ``matching.access_token`` with
+        # a narrower token while storing a wider ``scopes`` list — the
+        # record would claim authorizations the token does not grant, and
+        # blocks using the lost scopes would fail with opaque 401/403s
+        # until the user hits re-auth. On a narrowing login, keep the
+        # two credentials separate instead.
+        if set(credentials.scopes).issuperset(set(matching.scopes)):
+            return await _upgrade_existing_credential(user_id, matching.id, credentials)
+
+    await creds_manager.create(user_id, credentials)
+    return credentials
+
+
+async def _upgrade_existing_credential(
+    user_id: str,
+    existing_cred_id: str,
+    new_credentials: OAuth2Credentials,
+) -> OAuth2Credentials:
+    """Merge scopes from *new_credentials* into an existing credential."""
+    # Defense-in-depth: re-check system and provider invariants right before
+    # the write. The login-time check in `_prepare_scope_upgrade` can go stale
+    # by the time the callback runs, and the implicit-merge path bypasses
+    # login-time validation entirely, so every write-path must enforce these
+    # on its own.
+    if is_system_credential(existing_cred_id):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="System credentials cannot be upgraded",
+        )
+    existing = await creds_manager.store.get_creds_by_id(user_id, existing_cred_id)
+    if not existing or not isinstance(existing, OAuth2Credentials):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Credential to upgrade not found",
+        )
+    if existing.is_managed:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Managed credentials cannot be upgraded",
+        )
+    if not provider_matches(existing.provider, new_credentials.provider):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Credential provider does not match the requested provider",
+        )
+
+    if (
+        existing.username
+        and new_credentials.username
+        and existing.username != new_credentials.username
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Username mismatch: authenticated as a different user",
+        )
+
+    # Operate on a copy so the caller's ``new_credentials`` object is not
+    # mutated out from under them. Every caller today immediately discards
+    # or replaces its reference, but the implicit-merge path in
+    # ``_merge_or_create_credential`` reads ``credentials.scopes`` before
+    # calling into us — a future reader after the call would otherwise
+    # silently see the overwritten values.
+    merged = new_credentials.model_copy(deep=True)
+    merged.id = existing.id
+    merged.title = existing.title
+    merged.scopes = list(set(existing.scopes) | set(new_credentials.scopes))
+    merged.metadata = {
+        **(existing.metadata or {}),
+        **(new_credentials.metadata or {}),
+    }
+    # Preserve the existing refresh_token and username if the incremental
+    # response doesn't carry them. Providers like Google only return a
+    # refresh_token on first authorization — dropping it here would orphan
+    # the credential on the next access-token expiry, forcing the user to
+    # re-auth from scratch. Username is similarly sticky: if we've already
+    # resolved it for this credential, keep it rather than silently
+    # blanking it on an incremental upgrade.
+    if not merged.refresh_token and existing.refresh_token:
+        merged.refresh_token = existing.refresh_token
+        merged.refresh_token_expires_at = existing.refresh_token_expires_at
+    if not merged.username and existing.username:
+        merged.username = existing.username
+    await creds_manager.update(user_id, merged)
+    return merged
+
+
 # --------------------------- UTILITIES ---------------------------- #
 
 
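The merge decision in `_merge_or_create_credential` / `_upgrade_existing_credential` above comes down to two set operations: a superset guard before merging (so the stored record never advertises scopes its token does not grant) and a union afterwards. A minimal sketch of just that logic, detached from the credential objects:

```python
def should_merge(new_scopes: set[str], existing_scopes: set[str]) -> bool:
    # Merge only when the freshly issued token covers every scope the
    # stored credential already advertises; a narrowing login would
    # otherwise pair a wide `scopes` list with a narrow access token.
    return new_scopes.issuperset(existing_scopes)


def merged_scopes(new_scopes: set[str], existing_scopes: set[str]) -> list[str]:
    # On merge, the stored credential advertises the union of both grants
    # (sorted here only to make the result deterministic).
    return sorted(new_scopes | existing_scopes)
```

On a widening login the union equals the new scope set, so the merged record stays consistent with the token that backs it.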
@@ -784,12 +1122,21 @@ def _get_provider_oauth_handler(
 async def get_ayrshare_sso_url(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> AyrshareSSOResponse:
-    """
-    Generate an SSO URL for Ayrshare social media integration.
-
-    Returns:
-        dict: Contains the SSO URL for Ayrshare integration
-    """
+    """Generate a JWT SSO URL so the user can link their social accounts.
+
+    The per-user Ayrshare profile key is provisioned and persisted as a
+    standard ``is_managed=True`` credential by
+    :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`.
+    This endpoint only signs a short-lived JWT pointing at the Ayrshare-
+    hosted social-linking page; all profile lifecycle logic lives with the
+    managed provider.
+    """
+    if not ayrshare_settings_available():
+        raise HTTPException(
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Ayrshare integration is not configured",
+        )
+
     try:
         client = AyrshareClient()
     except MissingConfigError:
@@ -798,66 +1145,63 @@ async def get_ayrshare_sso_url(
             detail="Ayrshare integration is not configured",
         )
 
-    # Ayrshare profile key is stored in the credentials store
-    # It is generated when creating a new profile, if there is no profile key,
-    # we create a new profile and store the profile key in the credentials store
-    user_integrations: UserIntegrations = await get_user_integrations(user_id)
-    profile_key = user_integrations.managed_credentials.ayrshare_profile_key
-    if not profile_key:
-        logger.debug(f"Creating new Ayrshare profile for user {user_id}")
-        try:
-            profile = await client.create_profile(
-                title=f"User {user_id}", messaging_active=True
-            )
-            profile_key = profile.profileKey
-            await creds_manager.store.set_ayrshare_profile_key(user_id, profile_key)
-        except Exception as e:
-            logger.error(f"Error creating Ayrshare profile for user {user_id}: {e}")
-            raise HTTPException(
-                status_code=HTTP_502_BAD_GATEWAY,
-                detail="Failed to create Ayrshare profile",
-            )
-    else:
-        logger.debug(f"Using existing Ayrshare profile for user {user_id}")
-
-    profile_key_str = (
-        profile_key.get_secret_value()
-        if isinstance(profile_key, SecretStr)
-        else str(profile_key)
-    )
+    # On-demand provisioning: AyrshareManagedProvider opts out of the
+    # credentials sweep (profile quota is per-user subscription-bound). This
+    # endpoint is the only trigger that provisions a profile — one Ayrshare
+    # profile per user who actually opens the connect flow, not one per
+    # every authenticated user.
+    provisioned = await ensure_managed_credential(
+        user_id, creds_manager.store, AyrshareManagedProvider()
+    )
+    if not provisioned:
+        raise HTTPException(
+            status_code=HTTP_502_BAD_GATEWAY,
+            detail="Failed to provision Ayrshare profile",
+        )
+
+    ayrshare_creds = [
+        c
+        for c in await creds_manager.store.get_creds_by_provider(user_id, "ayrshare")
+        if c.is_managed and isinstance(c, APIKeyCredentials)
+    ]
+    if not ayrshare_creds:
+        logger.error(
+            "Ayrshare credential provisioning did not produce a credential "
+            "for user %s",
+            user_id,
+        )
+        raise HTTPException(
+            status_code=HTTP_502_BAD_GATEWAY,
+            detail="Failed to provision Ayrshare profile",
+        )
+    profile_key_str = ayrshare_creds[0].api_key.get_secret_value()
|
||||||
|
|
||||||
private_key = settings.secrets.ayrshare_jwt_key
|
private_key = settings.secrets.ayrshare_jwt_key
|
||||||
# Ayrshare JWT expiry is 2880 minutes (48 hours)
|
# Ayrshare JWT max lifetime is 2880 minutes (48 h).
|
||||||
max_expiry_minutes = 2880
|
max_expiry_minutes = 2880
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Generating Ayrshare JWT for user {user_id}")
|
|
||||||
jwt_response = await client.generate_jwt(
|
jwt_response = await client.generate_jwt(
|
||||||
private_key=private_key,
|
private_key=private_key,
|
||||||
profile_key=profile_key_str,
|
profile_key=profile_key_str,
|
||||||
|
# `allowed_social` is the set of networks the Ayrshare-hosted
|
||||||
|
# social-linking page will *offer* the user to connect. Blocks
|
||||||
|
# exist for more platforms than are listed here; the list is
|
||||||
|
# deliberately narrower so the rollout can verify each network
|
||||||
|
# end-to-end before widening the user-visible surface. Keep
|
||||||
|
# in sync with tested platforms — extend as each is verified
|
||||||
|
# against the block + Ayrshare's network-specific quirks.
|
||||||
allowed_social=[
|
allowed_social=[
|
||||||
# NOTE: We are enabling platforms one at a time
|
|
||||||
# to speed up the development process
|
|
||||||
# SocialPlatform.FACEBOOK,
|
|
||||||
SocialPlatform.TWITTER,
|
SocialPlatform.TWITTER,
|
||||||
SocialPlatform.LINKEDIN,
|
SocialPlatform.LINKEDIN,
|
||||||
SocialPlatform.INSTAGRAM,
|
SocialPlatform.INSTAGRAM,
|
||||||
SocialPlatform.YOUTUBE,
|
SocialPlatform.YOUTUBE,
|
||||||
# SocialPlatform.REDDIT,
|
|
||||||
# SocialPlatform.TELEGRAM,
|
|
||||||
# SocialPlatform.GOOGLE_MY_BUSINESS,
|
|
||||||
# SocialPlatform.PINTEREST,
|
|
||||||
SocialPlatform.TIKTOK,
|
SocialPlatform.TIKTOK,
|
||||||
# SocialPlatform.BLUESKY,
|
|
||||||
# SocialPlatform.SNAPCHAT,
|
|
||||||
# SocialPlatform.THREADS,
|
|
||||||
],
|
],
|
||||||
expires_in=max_expiry_minutes,
|
expires_in=max_expiry_minutes,
|
||||||
verify=True,
|
verify=True,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
logger.error(f"Error generating Ayrshare JWT for user {user_id}: {e}")
|
logger.error("Error generating Ayrshare JWT for user %s: %s", user_id, exc)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT"
|
status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT"
|
||||||
)
|
)
|
||||||
@@ -867,20 +1211,37 @@ async def get_ayrshare_sso_url(
|
|||||||
|
|
||||||
|
|
||||||
# === PROVIDER DISCOVERY ENDPOINTS ===
|
# === PROVIDER DISCOVERY ENDPOINTS ===
|
||||||
@router.get("/providers", response_model=List[str])
|
@router.get("/providers", response_model=List[ProviderMetadata])
|
||||||
async def list_providers() -> List[str]:
|
async def list_providers() -> List[ProviderMetadata]:
|
||||||
"""
|
"""
|
||||||
Get a list of all available provider names.
|
Get metadata for every available provider.
|
||||||
|
|
||||||
Returns both statically defined providers (from ProviderName enum)
|
Returns both statically defined providers (from ``ProviderName`` enum) and
|
||||||
and dynamically registered providers (from SDK decorators).
|
dynamically registered providers (from SDK decorators). Each entry includes
|
||||||
|
a ``description`` declared via ``ProviderBuilder.with_description(...)`` in
|
||||||
|
the provider's ``_config.py``.
|
||||||
|
|
||||||
Note: The complete list of provider names is also available as a constant
|
Note: The complete list of provider names is also available as a constant
|
||||||
in the generated TypeScript client via PROVIDER_NAMES.
|
in the generated TypeScript client via PROVIDER_NAMES.
|
||||||
"""
|
"""
|
||||||
# Get all providers at runtime
|
# Ensure all block modules (and therefore every provider's _config.py) are
|
||||||
|
# imported before we read from AutoRegistry. Cached on first call.
|
||||||
|
try:
|
||||||
|
from backend.blocks import load_all_blocks
|
||||||
|
|
||||||
|
load_all_blocks()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load blocks for provider metadata: {e}")
|
||||||
|
|
||||||
all_providers = get_all_provider_names()
|
all_providers = get_all_provider_names()
|
||||||
return all_providers
|
return [
|
||||||
|
ProviderMetadata(
|
||||||
|
name=name,
|
||||||
|
description=get_provider_description(name),
|
||||||
|
supported_auth_types=get_supported_auth_types(name),
|
||||||
|
)
|
||||||
|
for name in all_providers
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/providers/system", response_model=List[str])
|
@router.get("/providers/system", response_model=List[str])
|
||||||
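The `list_providers` change above leans on a common pattern: provider `_config.py` modules register themselves as an import side effect, a cached loader guarantees those imports have run, and the endpoint then maps registry entries to metadata objects. A minimal self-contained sketch of that pattern (all names here are hypothetical stand-ins, not AutoGPT's actual registry API):

```python
# Sketch of the import-then-read-registry pattern behind list_providers().
# `register_provider`, `_REGISTRY`, and the sample entries are illustrative
# assumptions, not the project's real AutoRegistry API.
from dataclasses import dataclass
from functools import lru_cache

_REGISTRY: dict[str, str] = {}  # provider name -> description


def register_provider(name: str, description: str) -> None:
    """What a provider's _config.py would do at import time."""
    _REGISTRY[name] = description


@lru_cache(maxsize=1)
def load_all_providers() -> None:
    # The real code imports every block module here; lru_cache makes the
    # (potentially slow) import sweep a one-time cost, like load_all_blocks().
    register_provider("google", "Google OAuth + Drive picker")
    register_provider("ayrshare", "Managed social posting")


@dataclass
class ProviderMetadata:
    name: str
    description: str


def list_providers() -> list[ProviderMetadata]:
    load_all_providers()  # idempotent: safe to call on every request
    return [ProviderMetadata(n, d) for n, d in sorted(_REGISTRY.items())]
```

The cache matters because the endpoint is hit per page load, while the block-import sweep only needs to happen once per process.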
@@ -393,7 +393,7 @@ class TestEnsureManagedCredentials:
         _PROVIDERS.update(saved)
         _provisioned_users.pop("user-1", None)

-        provider.provision.assert_awaited_once_with("user-1")
+        provider.provision.assert_awaited_once_with("user-1", store)
         store.add_managed_credential.assert_awaited_once_with("user-1", cred)

     @pytest.mark.asyncio
@@ -568,3 +568,181 @@ class TestCleanupManagedCredentials:
         _PROVIDERS.update(saved)

         # No exception raised — cleanup failure is swallowed.
+
+
+class TestGetPickerToken:
+    """POST /{provider}/credentials/{cred_id}/picker-token must:
+    1. Return the access token for OAuth2 creds the caller owns.
+    2. 404 for non-owned, non-existent, or wrong-provider creds.
+    3. 400 for non-OAuth2 creds (API key, host-scoped, user/password).
+    4. 404 for SDK default creds (same hardening as get_credential).
+    5. Preserve the `TestGetCredentialReturnsMetaOnly` contract — the
+       existing meta-only endpoint must still strip secrets even after
+       this picker-token endpoint exists."""
+
+    def test_oauth2_owner_gets_access_token(self):
+        # Use a Google cred with a drive.file scope — only picker-eligible
+        # (provider, scope) pairs can mint a token. GitHub-style creds are
+        # explicitly rejected; see `test_non_picker_provider_rejected_as_400`.
+        cred = _make_oauth2_cred(
+            cred_id="cred-gdrive",
+            provider="google",
+        )
+        cred.scopes = ["https://www.googleapis.com/auth/drive.file"]
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/google/credentials/cred-gdrive/picker-token")
+
+        assert resp.status_code == 200
+        data = resp.json()
+        # The whole point of this endpoint: the access token IS returned here.
+        assert data["access_token"] == "ghp_secret_token"
+        # Only the two declared fields come back — nothing else leaks.
+        assert set(data.keys()) <= {"access_token", "access_token_expires_at"}
+
+    def test_non_picker_provider_rejected_as_400(self):
+        """Provider allowlist: even with a valid OAuth2 credential, a
+        non-picker provider (GitHub, etc.) cannot mint a picker token.
+        Stops this endpoint from being used as a generic bearer-token
+        extraction path for any stored OAuth cred under the same user."""
+        cred = _make_oauth2_cred(provider="github")
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/github/credentials/cred-456/picker-token")
+
+        assert resp.status_code == 400
+        assert "not available for provider" in resp.json()["detail"]
+        assert "ghp_secret_token" not in str(resp.json())
+
+    def test_google_oauth_without_drive_scope_rejected(self):
+        """Scope allowlist: a Google OAuth2 cred that only carries non-picker
+        scopes (e.g. gmail.readonly, calendar) cannot mint a picker token.
+        Forces the frontend to reconnect with a Drive scope before the
+        picker is available."""
+        cred = _make_oauth2_cred(provider="google")
+        cred.scopes = [
+            "https://www.googleapis.com/auth/gmail.readonly",
+            "https://www.googleapis.com/auth/calendar",
+        ]
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/google/credentials/cred-456/picker-token")
+
+        assert resp.status_code == 400
+        assert "picker" in resp.json()["detail"].lower()
+
+    def test_api_key_credential_rejected_as_400(self):
+        cred = _make_api_key_cred()
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/openai/credentials/cred-123/picker-token")
+
+        assert resp.status_code == 400
+        # API keys must not silently fall through to a 200 response of some
+        # other shape — the client should see a clear shape rejection.
+        body = str(resp.json())
+        assert "sk-secret-key-value" not in body
+
+    def test_user_password_credential_rejected_as_400(self):
+        cred = _make_user_password_cred()
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/openai/credentials/cred-789/picker-token")
+
+        assert resp.status_code == 400
+        body = str(resp.json())
+        assert "s3cret-pass" not in body
+        assert "admin" not in body
+
+    def test_host_scoped_credential_rejected_as_400(self):
+        cred = _make_host_scoped_cred()
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/openai/credentials/cred-host/picker-token")
+
+        assert resp.status_code == 400
+        assert "top-secret" not in str(resp.json())
+
+    def test_missing_credential_returns_404(self):
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=None)
+            resp = client.post("/github/credentials/nonexistent/picker-token")
+
+        assert resp.status_code == 404
+        assert resp.json()["detail"] == "Credentials not found"
+
+    def test_wrong_provider_returns_404(self):
+        """Symmetric with get_credential: provider mismatch is a generic
+        404, not a 400, so we don't leak existence of a credential the
+        caller doesn't own on that provider."""
+        cred = _make_oauth2_cred(provider="github")
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/google/credentials/cred-456/picker-token")
+
+        assert resp.status_code == 404
+        assert resp.json()["detail"] == "Credentials not found"
+
+    def test_sdk_default_returns_404(self):
+        """SDK defaults are invisible to the user-facing API — picker-token
+        must not mint a token for them either."""
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock()
+            resp = client.post("/openai/credentials/openai-default/picker-token")
+
+        assert resp.status_code == 404
+        mock_mgr.get.assert_not_called()
+
+    def test_oauth2_without_access_token_returns_400(self):
+        """A stored OAuth2 cred whose access_token is missing can't satisfy
+        a picker init. Surface a clear reconnect instruction rather than
+        returning an empty string."""
+        cred = _make_oauth2_cred()
+        # Simulate a cred that lost its access token
+        object.__setattr__(cred, "access_token", None)
+
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.post("/github/credentials/cred-456/picker-token")
+
+        assert resp.status_code == 400
+        assert "reconnect" in resp.json()["detail"].lower()
+
+    def test_meta_only_endpoint_still_strips_access_token(self):
+        """Regression guard for the coexistence contract: the new
+        picker-token endpoint must NOT accidentally leak the token through
+        the meta-only GET endpoint. TestGetCredentialReturnsMetaOnly
+        covers this more broadly; this is a fast sanity check co-located
+        with the new endpoint's tests."""
+        cred = _make_oauth2_cred()
+        with patch(
+            "backend.api.features.integrations.router.creds_manager"
+        ) as mock_mgr:
+            mock_mgr.get = AsyncMock(return_value=cred)
+            resp = client.get("/github/credentials/cred-456")
+
+        assert resp.status_code == 200
+        body = resp.json()
+        assert "access_token" not in body
+        assert "refresh_token" not in body
+        assert "ghp_secret_token" not in str(body)
@@ -12,6 +12,7 @@ import prisma.models

 import backend.api.features.library.model as library_model
 import backend.data.graph as graph_db
+from backend.api.features.library.db import _fetch_schedule_info
 from backend.data.graph import GraphModel, GraphSettings
 from backend.data.includes import library_agent_include
 from backend.util.exceptions import NotFoundError
@@ -117,4 +118,5 @@ async def add_graph_to_library(
         f"for store listing version #{store_listing_version_id} "
         f"to library for user #{user_id}"
     )
-    return library_model.LibraryAgent.from_db(added_agent)
+    schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
+    return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)
@@ -21,13 +21,17 @@ async def test_add_graph_to_library_create_new_agent() -> None:
             "backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
             return_value=converted_agent,
         ) as mock_from_db,
+        patch(
+            "backend.api.features.library._add_to_library._fetch_schedule_info",
+            new=AsyncMock(return_value={}),
+        ),
     ):
         mock_prisma.return_value.create = AsyncMock(return_value=created_agent)

         result = await add_graph_to_library("slv-id", graph_model, "user-id")

     assert result is converted_agent
-    mock_from_db.assert_called_once_with(created_agent)
+    mock_from_db.assert_called_once_with(created_agent, schedule_info={})
     # Verify create was called with correct data
     create_call = mock_prisma.return_value.create.call_args
     create_data = create_call.kwargs["data"]
@@ -54,6 +58,10 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
             "backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
             return_value=converted_agent,
         ) as mock_from_db,
+        patch(
+            "backend.api.features.library._add_to_library._fetch_schedule_info",
+            new=AsyncMock(return_value={}),
+        ),
     ):
         mock_prisma.return_value.create = AsyncMock(
             side_effect=prisma.errors.UniqueViolationError(
@@ -65,7 +73,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
         result = await add_graph_to_library("slv-id", graph_model, "user-id")

     assert result is converted_agent
-    mock_from_db.assert_called_once_with(updated_agent)
+    mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
     # Verify update was called with correct where and data
     update_call = mock_prisma.return_value.update.call_args
     assert update_call.kwargs["where"] == {
@@ -1,6 +1,7 @@
 import asyncio
 import itertools
 import logging
+from datetime import datetime, timezone
 from typing import Literal, Optional

 import fastapi
@@ -43,6 +44,65 @@ config = Config()
 integration_creds_manager = IntegrationCredentialsManager()


+async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
+    """Fetch execution counts per graph in a single batched query."""
+    if not graph_ids:
+        return {}
+    rows = await prisma.models.AgentGraphExecution.prisma().group_by(
+        by=["agentGraphId"],
+        where={
+            "userId": user_id,
+            "agentGraphId": {"in": graph_ids},
+            "isDeleted": False,
+        },
+        count=True,
+    )
+    return {
+        row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
+        for row in rows
+    }
+
+
+async def _fetch_schedule_info(
+    user_id: str, graph_id: Optional[str] = None
+) -> dict[str, str]:
+    """Fetch a map of graph_id → earliest next_run_time ISO string.
+
+    When `graph_id` is provided, the scheduler query is narrowed to that graph,
+    which is cheaper for single-agent lookups (detail page, post-update, etc.).
+    """
+    try:
+        scheduler_client = get_scheduler_client()
+        schedules = await scheduler_client.get_execution_schedules(
+            graph_id=graph_id,
+            user_id=user_id,
+        )
+        earliest: dict[str, tuple[datetime, str]] = {}
+        for s in schedules:
+            parsed = _parse_iso_datetime(s.next_run_time)
+            if parsed is None:
+                continue
+            current = earliest.get(s.graph_id)
+            if current is None or parsed < current[0]:
+                earliest[s.graph_id] = (parsed, s.next_run_time)
+        return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
+    except Exception:
+        logger.warning("Failed to fetch schedules for library agents", exc_info=True)
+        return {}
+
+
+def _parse_iso_datetime(value: str) -> Optional[datetime]:
+    """Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
+    try:
+        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
+    except ValueError:
+        logger.warning("Failed to parse schedule next_run_time: %s", value)
+        return None
+    if parsed.tzinfo is None:
+        parsed = parsed.replace(tzinfo=timezone.utc)
+    return parsed
+
+
 async def list_library_agents(
     user_id: str,
     search_term: Optional[str] = None,
@@ -137,12 +197,22 @@ async def list_library_agents(

     logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")

+    graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
+    execution_counts, schedule_info = await asyncio.gather(
+        _fetch_execution_counts(user_id, graph_ids),
+        _fetch_schedule_info(user_id),
+    )
+
     # Only pass valid agents to the response
     valid_library_agents: list[library_model.LibraryAgent] = []

     for agent in library_agents:
         try:
-            library_agent = library_model.LibraryAgent.from_db(agent)
+            library_agent = library_model.LibraryAgent.from_db(
+                agent,
+                execution_count_override=execution_counts.get(agent.agentGraphId),
+                schedule_info=schedule_info,
+            )
             valid_library_agents.append(library_agent)
         except Exception as e:
             # Skip this agent if there was an error
@@ -214,12 +284,22 @@ async def list_favorite_library_agents(
         f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
     )

+    graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
+    execution_counts, schedule_info = await asyncio.gather(
+        _fetch_execution_counts(user_id, graph_ids),
+        _fetch_schedule_info(user_id),
+    )
+
     # Only pass valid agents to the response
     valid_library_agents: list[library_model.LibraryAgent] = []

     for agent in library_agents:
         try:
-            library_agent = library_model.LibraryAgent.from_db(agent)
+            library_agent = library_model.LibraryAgent.from_db(
+                agent,
+                execution_count_override=execution_counts.get(agent.agentGraphId),
+                schedule_info=schedule_info,
+            )
             valid_library_agents.append(library_agent)
         except Exception as e:
             # Skip this agent if there was an error
@@ -285,6 +365,12 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
             where={"userId": store_listing.owningUserId}
         )

+    schedule_info = (
+        await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
+        if library_agent.AgentGraph
+        else {}
+    )
+
     return library_model.LibraryAgent.from_db(
         library_agent,
         sub_graphs=(
@@ -294,6 +380,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
         ),
         store_listing=store_listing,
         profile=profile,
+        schedule_info=schedule_info,
     )


@@ -329,7 +416,10 @@ async def get_library_agent_by_store_version_id(
         },
         include=library_agent_include(user_id),
     )
-    return library_model.LibraryAgent.from_db(agent) if agent else None
+    if not agent:
+        return None
+    schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
+    return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)


 async def get_library_agent_by_graph_id(
@@ -358,7 +448,10 @@ async def get_library_agent_by_graph_id(
     assert agent.AgentGraph  # make type checker happy
     # Include sub-graphs so we can make a full credentials input schema
     sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
-    return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
+    schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
+    return library_model.LibraryAgent.from_db(
+        agent, sub_graphs=sub_graphs, schedule_info=schedule_info
+    )


 async def add_generated_agent_image(
@@ -500,7 +593,11 @@ async def create_library_agent(
     for agent, graph in zip(library_agents, graph_entries):
         asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))

-    return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
+    schedule_info = await _fetch_schedule_info(user_id)
+    return [
+        library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
+        for agent in library_agents
+    ]


 async def update_agent_version_in_library(
@@ -562,7 +659,8 @@ async def update_agent_version_in_library(
             f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
         )

-    return library_model.LibraryAgent.from_db(lib)
+    schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
+    return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)


 async def create_graph_in_library(
@@ -645,6 +743,7 @@ async def update_library_agent_version_and_settings(
         graph=agent_graph,
         hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
         sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
+        builder_chat_session_id=library.settings.builder_chat_session_id,
     )
     if updated_settings != library.settings:
         library = await update_library_agent(
@@ -1467,7 +1566,11 @@ async def bulk_move_agents_to_folder(
         ),
     )

-    return [library_model.LibraryAgent.from_db(agent) for agent in agents]
+    schedule_info = await _fetch_schedule_info(user_id)
+    return [
+        library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
+        for agent in agents
+    ]


 def collect_tree_ids(
@@ -1701,7 +1804,7 @@ async def create_preset_from_graph_execution(
         raise NotFoundError(
             f"Graph #{graph_execution.graph_id} not found or accessible"
         )
-    elif len(graph.aggregate_credentials_inputs()) > 0:
+    elif len(graph.regular_credentials_inputs) > 0:
         raise ValueError(
             f"Graph execution #{graph_exec_id} can't be turned into a preset "
             "because it was run before this feature existed "
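The core of the new `_fetch_schedule_info` helper is a per-graph minimum reduction over the scheduler's `next_run_time` strings, skipping unparseable entries and echoing back the original ISO string. A standalone sketch of that logic, with a stand-in `Schedule` tuple for the scheduler client's response objects:

```python
# Standalone sketch of the schedule-reduction logic added in this diff.
# `Schedule` is a hypothetical stand-in shape, not the real client model.
from datetime import datetime, timezone
from typing import NamedTuple, Optional


class Schedule(NamedTuple):
    graph_id: str
    next_run_time: str


def parse_iso(value: str) -> Optional[datetime]:
    """Tolerate `Z` suffixes and naive datetimes (assumed UTC), as the diff does."""
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)


def earliest_next_runs(schedules: list[Schedule]) -> dict[str, str]:
    earliest: dict[str, tuple[datetime, str]] = {}
    for s in schedules:
        parsed = parse_iso(s.next_run_time)
        if parsed is None:
            continue  # unparseable entries are skipped, not fatal
        current = earliest.get(s.graph_id)
        if current is None or parsed < current[0]:
            earliest[s.graph_id] = (parsed, s.next_run_time)
    # Return the *original* ISO string so the API echoes what the scheduler sent
    return {gid: iso for gid, (_, iso) in earliest.items()}
```

Keeping the parsed datetime only for comparison (and returning the raw string) avoids reformatting timestamps the scheduler already serialized.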
|
|||||||
@@ -65,6 +65,11 @@ async def test_get_library_agents(mocker):
     )
     mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
 
+    mocker.patch(
+        "backend.api.features.library.db._fetch_execution_counts",
+        new=mocker.AsyncMock(return_value={}),
+    )
+
     # Call function
     result = await db.list_library_agents("test-user")
 
@@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert():
     # Verify update branch restores soft-deleted/archived agents
     assert data["update"]["isDeleted"] is False
     assert data["update"]["isArchived"] is False
+
+
+@pytest.mark.asyncio
+async def test_list_favorite_library_agents(mocker):
+    mock_library_agents = [
+        prisma.models.LibraryAgent(
+            id="fav1",
+            userId="test-user",
+            agentGraphId="agent-fav",
+            settings="{}",  # type: ignore
+            agentGraphVersion=1,
+            isCreatedByUser=False,
+            isDeleted=False,
+            isArchived=False,
+            createdAt=datetime.now(),
+            updatedAt=datetime.now(),
+            isFavorite=True,
+            useGraphIsActiveVersion=True,
+            AgentGraph=prisma.models.AgentGraph(
+                id="agent-fav",
+                version=1,
+                name="Favorite Agent",
+                description="My Favorite",
+                userId="other-user",
+                isActive=True,
+                createdAt=datetime.now(),
+            ),
+        )
+    ]
+
+    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
+    mock_library_agent.return_value.find_many = mocker.AsyncMock(
+        return_value=mock_library_agents
+    )
+    mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
+
+    mocker.patch(
+        "backend.api.features.library.db._fetch_execution_counts",
+        new=mocker.AsyncMock(return_value={"agent-fav": 7}),
+    )
+
+    result = await db.list_favorite_library_agents("test-user")
+
+    assert len(result.agents) == 1
+    assert result.agents[0].id == "fav1"
+    assert result.agents[0].name == "Favorite Agent"
+    assert result.agents[0].graph_id == "agent-fav"
+    assert result.pagination.total_items == 1
+    assert result.pagination.total_pages == 1
+    assert result.pagination.current_page == 1
+    assert result.pagination.page_size == 50
+
+
+@pytest.mark.asyncio
+async def test_list_library_agents_skips_failed_agent(mocker):
+    """Agents that fail parsing should be skipped — covers the except branch."""
+    mock_library_agents = [
+        prisma.models.LibraryAgent(
+            id="ua-bad",
+            userId="test-user",
+            agentGraphId="agent-bad",
+            settings="{}",  # type: ignore
+            agentGraphVersion=1,
+            isCreatedByUser=False,
+            isDeleted=False,
+            isArchived=False,
+            createdAt=datetime.now(),
+            updatedAt=datetime.now(),
+            isFavorite=False,
+            useGraphIsActiveVersion=True,
+            AgentGraph=prisma.models.AgentGraph(
+                id="agent-bad",
+                version=1,
+                name="Bad Agent",
+                description="",
+                userId="other-user",
+                isActive=True,
+                createdAt=datetime.now(),
+            ),
+        )
+    ]
+
+    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
+    mock_library_agent.return_value.find_many = mocker.AsyncMock(
+        return_value=mock_library_agents
+    )
+    mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
+
+    mocker.patch(
+        "backend.api.features.library.db._fetch_execution_counts",
+        new=mocker.AsyncMock(return_value={}),
+    )
+    mocker.patch(
+        "backend.api.features.library.model.LibraryAgent.from_db",
+        side_effect=Exception("parse error"),
+    )
+
+    result = await db.list_library_agents("test-user")
+
+    assert len(result.agents) == 0
+    assert result.pagination.total_items == 1
+
+
+@pytest.mark.asyncio
+async def test_fetch_execution_counts_empty_graph_ids():
+    result = await db._fetch_execution_counts("user-1", [])
+    assert result == {}
+
+
+@pytest.mark.asyncio
+async def test_fetch_execution_counts_uses_group_by(mocker):
+    mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
+    mock_prisma.return_value.group_by = mocker.AsyncMock(
+        return_value=[
+            {"agentGraphId": "graph-1", "_count": {"_all": 5}},
+            {"agentGraphId": "graph-2", "_count": {"_all": 2}},
+        ]
+    )
+
+    result = await db._fetch_execution_counts(
+        "user-1", ["graph-1", "graph-2", "graph-3"]
+    )
+
+    assert result == {"graph-1": 5, "graph-2": 2}
+    mock_prisma.return_value.group_by.assert_called_once_with(
+        by=["agentGraphId"],
+        where={
+            "userId": "user-1",
+            "agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
+            "isDeleted": False,
+        },
+        count=True,
+    )
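The `_fetch_execution_counts` helper mocked and asserted above reduces Prisma `group_by` rows to a per-graph count map. A minimal stand-alone sketch of that reshaping step, with the row shape taken from the test's mock return value (the helper name `rows_to_counts` is hypothetical):

```python
# Sketch: turn Prisma-style group_by rows into a {graph_id: count} map,
# mirroring the shape asserted in test_fetch_execution_counts_uses_group_by.
rows = [
    {"agentGraphId": "graph-1", "_count": {"_all": 5}},
    {"agentGraphId": "graph-2", "_count": {"_all": 2}},
]


def rows_to_counts(rows: list[dict]) -> dict[str, int]:
    # Graphs with no executions simply produce no row, so they are
    # absent from the map rather than present with a zero count.
    return {row["agentGraphId"]: row["_count"]["_all"] for row in rows}


print(rows_to_counts(rows))  # {'graph-1': 5, 'graph-2': 2}
```

This matches the test's expectation that `graph-3` (queried but never executed) is simply missing from the result.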
@@ -214,6 +214,14 @@ class LibraryAgent(pydantic.BaseModel):
     folder_name: str | None = None  # Denormalized for display
 
     recommended_schedule_cron: str | None = None
+    is_scheduled: bool = pydantic.Field(
+        default=False,
+        description="Whether this agent has active execution schedules",
+    )
+    next_scheduled_run: str | None = pydantic.Field(
+        default=None,
+        description="ISO 8601 timestamp of the next scheduled run, if any",
+    )
     settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
     marketplace_listing: Optional["MarketplaceListing"] = None
 
@@ -223,6 +231,8 @@ class LibraryAgent(pydantic.BaseModel):
         sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
         store_listing: Optional[prisma.models.StoreListing] = None,
         profile: Optional[prisma.models.Profile] = None,
+        execution_count_override: Optional[int] = None,
+        schedule_info: Optional[dict[str, str]] = None,
     ) -> "LibraryAgent":
         """
         Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
@@ -258,10 +268,14 @@ class LibraryAgent(pydantic.BaseModel):
         status = status_result.status
         new_output = status_result.new_output
 
-        execution_count = len(executions)
+        execution_count = (
+            execution_count_override
+            if execution_count_override is not None
+            else len(executions)
+        )
         success_rate: float | None = None
         avg_correctness_score: float | None = None
-        if execution_count > 0:
+        if executions and execution_count > 0:
             success_count = sum(
                 1
                 for e in executions
@@ -354,6 +368,10 @@ class LibraryAgent(pydantic.BaseModel):
             folder_id=agent.folderId,
             folder_name=agent.Folder.name if agent.Folder else None,
             recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
+            is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
+            next_scheduled_run=(
+                schedule_info.get(agent.agentGraphId) if schedule_info else None
+            ),
             settings=_parse_settings(agent.settings),
             marketplace_listing=marketplace_listing_data,
         )
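Both schedule fields added to `LibraryAgent` above derive from a single optional `schedule_info` dict keyed by graph ID. A minimal sketch of that derivation, with the two expressions lifted from the diff (the wrapper function `derive_schedule_fields` is hypothetical):

```python
from typing import Optional


def derive_schedule_fields(
    graph_id: str, schedule_info: Optional[dict[str, str]]
) -> tuple[bool, Optional[str]]:
    # Mirrors the two expressions passed to LibraryAgent in from_db:
    # is_scheduled is True only when the graph has an entry, and
    # next_scheduled_run is that entry's ISO 8601 timestamp (or None).
    is_scheduled = bool(schedule_info and graph_id in schedule_info)
    next_run = schedule_info.get(graph_id) if schedule_info else None
    return is_scheduled, next_run


print(derive_schedule_fields("g1", {"g1": "2026-01-01T00:00:00Z"}))
# (True, '2026-01-01T00:00:00Z')
print(derive_schedule_fields("g2", None))
# (False, None)
```

Because `schedule_info` defaults to `None`, callers that do not prefetch schedules (like single-agent lookups) get the old behavior unchanged.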
@@ -1,11 +1,66 @@
 import datetime
 
+import prisma.enums
 import prisma.models
 import pytest
 
 from . import model as library_model
 
 
+def _make_library_agent(
+    *,
+    graph_id: str = "g1",
+    executions: list | None = None,
+) -> prisma.models.LibraryAgent:
+    return prisma.models.LibraryAgent(
+        id="la1",
+        userId="u1",
+        agentGraphId=graph_id,
+        settings="{}",  # type: ignore
+        agentGraphVersion=1,
+        isCreatedByUser=True,
+        isDeleted=False,
+        isArchived=False,
+        createdAt=datetime.datetime.now(),
+        updatedAt=datetime.datetime.now(),
+        isFavorite=False,
+        useGraphIsActiveVersion=True,
+        AgentGraph=prisma.models.AgentGraph(
+            id=graph_id,
+            version=1,
+            name="Agent",
+            description="Desc",
+            userId="u1",
+            isActive=True,
+            createdAt=datetime.datetime.now(),
+            Executions=executions,
+        ),
+    )
+
+
+def test_from_db_execution_count_override_covers_success_rate():
+    """Covers execution_count_override is not None branch and executions/count > 0 block."""
+    now = datetime.datetime.now(datetime.timezone.utc)
+    exec1 = prisma.models.AgentGraphExecution(
+        id="exec-1",
+        agentGraphId="g1",
+        agentGraphVersion=1,
+        userId="u1",
+        executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
+        createdAt=now,
+        updatedAt=now,
+        isDeleted=False,
+        isShared=False,
+    )
+    agent = _make_library_agent(executions=[exec1])
+
+    result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
+
+    assert result.execution_count == 1
+    assert result.success_rate is not None
+    assert result.success_rate == 100.0
+
+
 @pytest.mark.asyncio
 async def test_agent_preset_from_db(test_user_id: str):
     # Create mock DB agent
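The override test above expects one COMPLETED execution with `execution_count_override=1` to yield a 100.0 success rate. The arithmetic implied by the model diff can be checked in isolation (statuses reduced to plain strings for illustration; variable names follow the diff):

```python
# Success rate as computed by LibraryAgent.from_db after the diff:
# the count may come from execution_count_override rather than len(executions).
executions = ["COMPLETED"]  # one completed run, as in the test fixture
execution_count_override = 1

execution_count = (
    execution_count_override
    if execution_count_override is not None
    else len(executions)
)
success_count = sum(1 for status in executions if status == "COMPLETED")
success_rate = (
    success_count / execution_count * 100
    if executions and execution_count > 0
    else None
)
print(success_rate)  # 100.0
```

Note the guard `executions and execution_count > 0`: with an override set but no fetched executions, the rate stays `None` instead of dividing by a count the rows cannot support.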
@@ -0,0 +1 @@
+"""Platform bot linking — user-facing REST routes."""
@@ -0,0 +1,158 @@
+"""User-facing platform_linking REST routes (JWT auth)."""
+
+import logging
+from typing import Annotated
+
+from autogpt_libs import auth
+from fastapi import APIRouter, HTTPException, Path, Security
+
+from backend.data.db_accessors import platform_linking_db
+from backend.platform_linking.models import (
+    ConfirmLinkResponse,
+    ConfirmUserLinkResponse,
+    DeleteLinkResponse,
+    LinkTokenInfoResponse,
+    PlatformLinkInfo,
+    PlatformUserLinkInfo,
+)
+from backend.util.exceptions import (
+    LinkAlreadyExistsError,
+    LinkFlowMismatchError,
+    LinkTokenExpiredError,
+    NotAuthorizedError,
+    NotFoundError,
+)
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+TokenPath = Annotated[
+    str,
+    Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
+]
+
+
+def _translate(exc: Exception) -> HTTPException:
+    if isinstance(exc, NotFoundError):
+        return HTTPException(status_code=404, detail=str(exc))
+    if isinstance(exc, NotAuthorizedError):
+        return HTTPException(status_code=403, detail=str(exc))
+    if isinstance(exc, LinkAlreadyExistsError):
+        return HTTPException(status_code=409, detail=str(exc))
+    if isinstance(exc, LinkTokenExpiredError):
+        return HTTPException(status_code=410, detail=str(exc))
+    if isinstance(exc, LinkFlowMismatchError):
+        return HTTPException(status_code=400, detail=str(exc))
+    return HTTPException(status_code=500, detail="Internal error.")
+
+
+@router.get(
+    "/tokens/{token}/info",
+    response_model=LinkTokenInfoResponse,
+    dependencies=[Security(auth.requires_user)],
+    summary="Get display info for a link token",
+)
+async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
+    try:
+        return await platform_linking_db().get_link_token_info(token)
+    except (NotFoundError, LinkTokenExpiredError) as exc:
+        raise _translate(exc) from exc
+
+
+@router.post(
+    "/tokens/{token}/confirm",
+    response_model=ConfirmLinkResponse,
+    dependencies=[Security(auth.requires_user)],
+    summary="Confirm a SERVER link token (user must be authenticated)",
+)
+async def confirm_link_token(
+    token: TokenPath,
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> ConfirmLinkResponse:
+    try:
+        return await platform_linking_db().confirm_server_link(token, user_id)
+    except (
+        NotFoundError,
+        LinkFlowMismatchError,
+        LinkTokenExpiredError,
+        LinkAlreadyExistsError,
+    ) as exc:
+        raise _translate(exc) from exc
+
+
+@router.post(
+    "/user-tokens/{token}/confirm",
+    response_model=ConfirmUserLinkResponse,
+    dependencies=[Security(auth.requires_user)],
+    summary="Confirm a USER link token (user must be authenticated)",
+)
+async def confirm_user_link_token(
+    token: TokenPath,
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> ConfirmUserLinkResponse:
+    try:
+        return await platform_linking_db().confirm_user_link(token, user_id)
+    except (
+        NotFoundError,
+        LinkFlowMismatchError,
+        LinkTokenExpiredError,
+        LinkAlreadyExistsError,
+    ) as exc:
+        raise _translate(exc) from exc
+
+
+@router.get(
+    "/links",
+    response_model=list[PlatformLinkInfo],
+    dependencies=[Security(auth.requires_user)],
+    summary="List all platform servers linked to the authenticated user",
+)
+async def list_my_links(
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> list[PlatformLinkInfo]:
+    return await platform_linking_db().list_server_links(user_id)
+
+
+@router.get(
+    "/user-links",
+    response_model=list[PlatformUserLinkInfo],
+    dependencies=[Security(auth.requires_user)],
+    summary="List all DM links for the authenticated user",
+)
+async def list_my_user_links(
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> list[PlatformUserLinkInfo]:
+    return await platform_linking_db().list_user_links(user_id)
+
+
+@router.delete(
+    "/links/{link_id}",
+    response_model=DeleteLinkResponse,
+    dependencies=[Security(auth.requires_user)],
+    summary="Unlink a platform server",
+)
+async def delete_link(
+    link_id: str,
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> DeleteLinkResponse:
+    try:
+        return await platform_linking_db().delete_server_link(link_id, user_id)
+    except (NotFoundError, NotAuthorizedError) as exc:
+        raise _translate(exc) from exc
+
+
+@router.delete(
+    "/user-links/{link_id}",
+    response_model=DeleteLinkResponse,
+    dependencies=[Security(auth.requires_user)],
+    summary="Unlink a DM / user link",
+)
+async def delete_user_link_route(
+    link_id: str,
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> DeleteLinkResponse:
+    try:
+        return await platform_linking_db().delete_user_link(link_id, user_id)
+    except (NotFoundError, NotAuthorizedError) as exc:
+        raise _translate(exc) from exc
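The `_translate` helper in the new routes module is a flat type-to-status mapping. The same precedence can be sketched framework-free with stand-in exception classes, no FastAPI (`status_for` and `_STATUS_BY_TYPE` are hypothetical names for this sketch):

```python
# Stand-ins for the backend.util.exceptions classes used by the routes.
class NotFoundError(Exception): ...
class NotAuthorizedError(Exception): ...
class LinkAlreadyExistsError(Exception): ...
class LinkTokenExpiredError(Exception): ...
class LinkFlowMismatchError(Exception): ...


# Same order as the isinstance chain in _translate; first match wins.
_STATUS_BY_TYPE = [
    (NotFoundError, 404),
    (NotAuthorizedError, 403),
    (LinkAlreadyExistsError, 409),
    (LinkTokenExpiredError, 410),
    (LinkFlowMismatchError, 400),
]


def status_for(exc: Exception) -> int:
    for exc_type, status in _STATUS_BY_TYPE:
        if isinstance(exc, exc_type):
            return status
    return 500  # anything unrecognized is an internal error


print(status_for(LinkTokenExpiredError("expired")))  # 410
print(status_for(RuntimeError("boom")))  # 500
```

Each route handler catches only the domain exceptions it can actually raise, so an unexpected error still surfaces as a 500 rather than being silently mapped.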
@@ -0,0 +1,264 @@
+"""Route tests: domain exceptions → HTTPException status codes."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from fastapi import HTTPException
+
+from backend.util.exceptions import (
+    LinkAlreadyExistsError,
+    LinkFlowMismatchError,
+    LinkTokenExpiredError,
+    NotAuthorizedError,
+    NotFoundError,
+)
+
+
+def _db_mock(**method_configs):
+    """Return a mock of the accessor's return value with the given AsyncMocks."""
+    db = MagicMock()
+    for name, mock in method_configs.items():
+        setattr(db, name, mock)
+    return db
+
+
+class TestTokenInfoRouteTranslation:
+    @pytest.mark.asyncio
+    async def test_not_found_maps_to_404(self):
+        from backend.api.features.platform_linking.routes import (
+            get_link_token_info_route,
+        )
+
+        db = _db_mock(
+            get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as exc:
+                await get_link_token_info_route(token="abc")
+            assert exc.value.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_expired_maps_to_410(self):
+        from backend.api.features.platform_linking.routes import (
+            get_link_token_info_route,
+        )
+
+        db = _db_mock(
+            get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as exc:
+                await get_link_token_info_route(token="abc")
+            assert exc.value.status_code == 410
+
+
+class TestConfirmLinkRouteTranslation:
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "exc,expected_status",
+        [
+            (NotFoundError("missing"), 404),
+            (LinkFlowMismatchError("wrong flow"), 400),
+            (LinkTokenExpiredError("expired"), 410),
+            (LinkAlreadyExistsError("already"), 409),
+        ],
+    )
+    async def test_translation(self, exc: Exception, expected_status: int):
+        from backend.api.features.platform_linking.routes import confirm_link_token
+
+        db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as ctx:
+                await confirm_link_token(token="abc", user_id="u1")
+            assert ctx.value.status_code == expected_status
+
+
+class TestConfirmUserLinkRouteTranslation:
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "exc,expected_status",
+        [
+            (NotFoundError("missing"), 404),
+            (LinkFlowMismatchError("wrong flow"), 400),
+            (LinkTokenExpiredError("expired"), 410),
+            (LinkAlreadyExistsError("already"), 409),
+        ],
+    )
+    async def test_translation(self, exc: Exception, expected_status: int):
+        from backend.api.features.platform_linking.routes import confirm_user_link_token
+
+        db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as ctx:
+                await confirm_user_link_token(token="abc", user_id="u1")
+            assert ctx.value.status_code == expected_status
+
+
+class TestDeleteLinkRouteTranslation:
+    @pytest.mark.asyncio
+    async def test_not_found_maps_to_404(self):
+        from backend.api.features.platform_linking.routes import delete_link
+
+        db = _db_mock(
+            delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as exc:
+                await delete_link(link_id="x", user_id="u1")
+            assert exc.value.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_not_owned_maps_to_403(self):
+        from backend.api.features.platform_linking.routes import delete_link
+
+        db = _db_mock(
+            delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as exc:
+                await delete_link(link_id="x", user_id="u1")
+            assert exc.value.status_code == 403
+
+
+class TestDeleteUserLinkRouteTranslation:
+    @pytest.mark.asyncio
+    async def test_not_found_maps_to_404(self):
+        from backend.api.features.platform_linking.routes import delete_user_link_route
+
+        db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as exc:
+                await delete_user_link_route(link_id="x", user_id="u1")
+            assert exc.value.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_not_owned_maps_to_403(self):
+        from backend.api.features.platform_linking.routes import delete_user_link_route
+
+        db = _db_mock(
+            delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            with pytest.raises(HTTPException) as exc:
+                await delete_user_link_route(link_id="x", user_id="u1")
+            assert exc.value.status_code == 403
+
+
+# ── Adversarial: malformed token path params ──────────────────────────
+
+
+class TestAdversarialTokenPath:
+    # TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.
+
+    @pytest.fixture
+    def client(self):
+        import fastapi
+        from autogpt_libs.auth import get_user_id, requires_user
+        from fastapi.testclient import TestClient
+
+        import backend.api.features.platform_linking.routes as routes_mod
+
+        app = fastapi.FastAPI()
+        app.dependency_overrides[requires_user] = lambda: None
+        app.dependency_overrides[get_user_id] = lambda: "caller-user"
+        app.include_router(routes_mod.router, prefix="/api/platform-linking")
+        return TestClient(app)
+
+    def test_rejects_token_with_special_chars(self, client):
+        response = client.get("/api/platform-linking/tokens/bad%24token/info")
+        assert response.status_code == 422
+
+    def test_rejects_token_with_path_traversal(self, client):
+        for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
+            response = client.get(f"/api/platform-linking/tokens/{probe}/info")
+            assert response.status_code in (
+                404,
+                422,
+            ), f"path-traversal probe {probe!r} returned {response.status_code}"
+
+    def test_rejects_token_too_long(self, client):
+        long_token = "a" * 65
+        response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
+        assert response.status_code == 422
+
+    def test_accepts_token_at_max_length(self, client):
+        token = "a" * 64
+        db = _db_mock(
+            get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            response = client.get(f"/api/platform-linking/tokens/{token}/info")
+            assert response.status_code == 404
+
+    def test_accepts_urlsafe_b64_token_shape(self, client):
+        db = _db_mock(
+            get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
+            assert response.status_code == 404
+
+    def test_confirm_rejects_malformed_token(self, client):
+        response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
+        assert response.status_code == 422
+
+
+class TestAdversarialDeleteLinkId:
+    """DELETE link_id has no regex — ensure weird values are handled via
+    NotFoundError (no crash, no cross-user leak)."""
+
+    @pytest.fixture
+    def client(self):
+        import fastapi
+        from autogpt_libs.auth import get_user_id, requires_user
+        from fastapi.testclient import TestClient
+
+        import backend.api.features.platform_linking.routes as routes_mod
+
+        app = fastapi.FastAPI()
+        app.dependency_overrides[requires_user] = lambda: None
+        app.dependency_overrides[get_user_id] = lambda: "caller-user"
+        app.include_router(routes_mod.router, prefix="/api/platform-linking")
+        return TestClient(app)
+
+    def test_weird_link_id_returns_404(self, client):
+        db = _db_mock(
+            delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
+        )
+        with patch(
+            "backend.api.features.platform_linking.routes.platform_linking_db",
+            return_value=db,
+        ):
+            for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
+                response = client.delete(f"/api/platform-linking/links/{link_id}")
+                assert response.status_code in (404, 405)
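The adversarial tests above lean on FastAPI enforcing `TokenPath`'s constraints at the routing layer. The same check can be reproduced with the stdlib to preview which probes should pass (pattern and max length copied from the routes module; `is_valid_token` is a hypothetical helper, not part of the codebase):

```python
import re

# Constraints copied from TokenPath: Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$")
TOKEN_RE = re.compile(r"^[A-Za-z0-9_-]+$")
MAX_LEN = 64


def is_valid_token(token: str) -> bool:
    return len(token) <= MAX_LEN and bool(TOKEN_RE.fullmatch(token))


print(is_valid_token("abc-_XYZ123-_abc"))  # True: URL-safe base64 alphabet
print(is_valid_token("a" * 64))            # True: exactly at the limit
print(is_valid_token("a" * 65))            # False: too long
print(is_valid_token("bad$token"))         # False: '$' not in the alphabet
print(is_valid_token("foo..bar"))          # False: '.' not in the alphabet
```

Rejecting `.` and `/` at the pattern level is what makes the path-traversal probes in the tests come back as 422 (or 404 when the router never matches) instead of reaching a handler.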
autogpt_platform/backend/backend/api/features/push/model.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+import pydantic
+
+
+class PushSubscriptionKeys(pydantic.BaseModel):
+    p256dh: str = pydantic.Field(min_length=1, max_length=512)
+    auth: str = pydantic.Field(min_length=1, max_length=512)
+
+
+class PushSubscribeRequest(pydantic.BaseModel):
+    endpoint: str = pydantic.Field(min_length=1, max_length=2048)
+    keys: PushSubscriptionKeys
+    user_agent: str | None = pydantic.Field(default=None, max_length=512)
+
+
+class PushUnsubscribeRequest(pydantic.BaseModel):
+    endpoint: str = pydantic.Field(min_length=1, max_length=2048)
+
+
+class VapidPublicKeyResponse(pydantic.BaseModel):
+    public_key: str
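The pydantic models above impose plain length bounds on strings. A stdlib-only sketch of the same constraint logic (the real models rely on pydantic's `Field(min_length=..., max_length=...)` validation; `check_length` and the sample endpoint URL are hypothetical):

```python
def check_length(value: str, *, min_length: int = 0, max_length: int) -> bool:
    # Mirrors pydantic's Field(min_length=..., max_length=...) constraint:
    # inclusive bounds on the string's character count.
    return min_length <= len(value) <= max_length


# endpoint: 1..2048 characters; each subscription key: 1..512 characters.
print(check_length("https://push.example/sub/abc", min_length=1, max_length=2048))  # True
print(check_length("", min_length=1, max_length=2048))                              # False
print(check_length("x" * 513, min_length=1, max_length=512))                        # False
```

Bounding `endpoint` and the key material caps what an authenticated client can persist per subscription row, without attempting to validate the URL's content here.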
64
autogpt_platform/backend/backend/api/features/push/routes.py
Normal file
64
autogpt_platform/backend/backend/api/features/push/routes.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from autogpt_libs.auth import get_user_id, requires_user
|
||||||
|
from fastapi import APIRouter, HTTPException, Security
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

from backend.api.features.push.model import (
    PushSubscribeRequest,
    PushUnsubscribeRequest,
    VapidPublicKeyResponse,
)
from backend.data.push_subscription import (
    delete_push_subscription,
    upsert_push_subscription,
    validate_push_endpoint,
)
from backend.util.settings import Settings

router = APIRouter()
_settings = Settings()


@router.get(
    "/vapid-key",
    summary="Get VAPID public key for push subscription",
)
async def get_vapid_public_key() -> VapidPublicKeyResponse:
    return VapidPublicKeyResponse(public_key=_settings.secrets.vapid_public_key)


@router.post(
    "/subscribe",
    summary="Register a push subscription for the current user",
    status_code=HTTP_204_NO_CONTENT,
    dependencies=[Security(requires_user)],
)
async def subscribe_push(
    user_id: Annotated[str, Security(get_user_id)],
    body: PushSubscribeRequest,
) -> None:
    try:
        await validate_push_endpoint(body.endpoint)
        await upsert_push_subscription(
            user_id=user_id,
            endpoint=body.endpoint,
            p256dh=body.keys.p256dh,
            auth=body.keys.auth,
            user_agent=body.user_agent,
        )
    except ValueError as e:
        raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=str(e))


@router.post(
    "/unsubscribe",
    summary="Remove a push subscription",
    status_code=HTTP_204_NO_CONTENT,
    dependencies=[Security(requires_user)],
)
async def unsubscribe_push(
    user_id: Annotated[str, Security(get_user_id)],
    body: PushUnsubscribeRequest,
) -> None:
    await delete_push_subscription(user_id, body.endpoint)
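The routes above gate every new subscription behind `validate_push_endpoint`, whose implementation is not part of this hunk. As a rough sketch of the kind of check the tests below exercise (HTTPS-only, known push-service hosts), with a hypothetical helper name and an illustrative host list:

```python
from urllib.parse import urlparse

# Illustrative allowlist only; the real list lives in
# backend.data.push_subscription.validate_push_endpoint.
TRUSTED_PUSH_HOSTS = {
    "fcm.googleapis.com",
    "updates.push.services.mozilla.com",
}


def is_trusted_push_endpoint(endpoint: str) -> bool:
    """Return True only for HTTPS URLs pointing at a known push service."""
    parsed = urlparse(endpoint)
    if parsed.scheme != "https":  # rejects http:// and file:// schemes
        return False
    return parsed.hostname in TRUSTED_PUSH_HOSTS
```

The real validator is async and raises `ValueError` (surfaced by the route as HTTP 400); this sketch only returns a boolean.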
@@ -0,0 +1,240 @@
"""Tests for push notification routes."""

from unittest.mock import AsyncMock, MagicMock

import fastapi
import fastapi.testclient
import pytest

from backend.api.features.push.routes import router

app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_get_vapid_public_key(mocker):
    mock_settings = MagicMock()
    mock_settings.secrets.vapid_public_key = "test-vapid-public-key-base64url"
    mocker.patch(
        "backend.api.features.push.routes._settings",
        mock_settings,
    )

    response = client.get("/vapid-key")

    assert response.status_code == 200
    data = response.json()
    assert data["public_key"] == "test-vapid-public-key-base64url"


def test_get_vapid_public_key_empty(mocker):
    mock_settings = MagicMock()
    mock_settings.secrets.vapid_public_key = ""
    mocker.patch(
        "backend.api.features.push.routes._settings",
        mock_settings,
    )

    response = client.get("/vapid-key")

    assert response.status_code == 200
    data = response.json()
    assert data["public_key"] == ""


def test_subscribe_push(mocker, test_user_id):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
            "user_agent": "Mozilla/5.0 Test",
        },
    )

    assert response.status_code == 204
    mock_upsert.assert_awaited_once_with(
        user_id=test_user_id,
        endpoint="https://fcm.googleapis.com/fcm/send/abc123",
        p256dh="test-p256dh-key",
        auth="test-auth-key",
        user_agent="Mozilla/5.0 Test",
    )


def test_subscribe_push_without_user_agent(mocker, test_user_id):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 204
    mock_upsert.assert_awaited_once_with(
        user_id=test_user_id,
        endpoint="https://fcm.googleapis.com/fcm/send/abc123",
        p256dh="test-p256dh-key",
        auth="test-auth-key",
        user_agent=None,
    )


def test_subscribe_push_missing_keys():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
        },
    )

    assert response.status_code == 422


def test_subscribe_push_missing_endpoint():
    response = client.post(
        "/subscribe",
        json={
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 422


def test_subscribe_push_rejects_empty_crypto_keys():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {"p256dh": "", "auth": ""},
        },
    )

    assert response.status_code == 422


def test_subscribe_push_rejects_oversized_endpoint():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/" + "x" * 3000,
            "keys": {"p256dh": "k", "auth": "a"},
        },
    )

    assert response.status_code == 422


def test_unsubscribe_push(mocker, test_user_id):
    mock_delete = mocker.patch(
        "backend.api.features.push.routes.delete_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/unsubscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
        },
    )

    assert response.status_code == 204
    mock_delete.assert_awaited_once_with(
        test_user_id,
        "https://fcm.googleapis.com/fcm/send/abc123",
    )


def test_unsubscribe_push_missing_endpoint():
    response = client.post(
        "/unsubscribe",
        json={},
    )

    assert response.status_code == 422


@pytest.mark.parametrize(
    "untrusted_endpoint",
    [
        "https://localhost/evil",
        "https://127.0.0.1/evil",
        "https://169.254.169.254/latest/meta-data/",
        "https://internal-service.local/api",
        "https://attacker.example.com/push",
        "http://fcm.googleapis.com/fcm/send/abc",
        "file:///etc/passwd",
    ],
)
def test_subscribe_push_rejects_untrusted_endpoints(mocker, untrusted_endpoint):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": untrusted_endpoint,
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 400
    mock_upsert.assert_not_awaited()


def test_subscribe_push_surfaces_cap_as_400(mocker):
    mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
        side_effect=ValueError("Subscription limit of 20 per user reached"),
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 400
    assert "Subscription limit" in response.json()["detail"]
@@ -490,6 +490,9 @@ async def get_store_creators(
     # Build where clause with sanitized inputs
     where = {}
 
+    # Only return creators with approved agents
+    where["num_agents"] = {"gt": 0}
+
     if featured:
         where["is_featured"] = featured
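The added `where["num_agents"] = {"gt": 0}` entry is a Prisma-style filter dict. A minimal sketch of the comparison it expresses, evaluated in plain Python with a hypothetical `matches_gt_filter` helper:

```python
def matches_gt_filter(record: dict, where: dict) -> bool:
    """Evaluate a Prisma-style {"field": {"gt": n}} filter against a record."""
    for field, condition in where.items():
        if isinstance(condition, dict) and "gt" in condition:
            if not record.get(field, 0) > condition["gt"]:
                return False
        elif record.get(field) != condition:
            return False
    return True


creators = [
    {"username": "creator1", "num_agents": 1},
    {"username": "creator2", "num_agents": 0},
]
# Only creators with at least one approved agent survive the filter.
approved = [c for c in creators if matches_gt_filter(c, {"num_agents": {"gt": 0}})]
```

In the real route the filter is applied by the database, not in Python; this sketch only illustrates its meaning.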
@@ -1,4 +1,5 @@
 from datetime import datetime
+from unittest.mock import AsyncMock
 
 import prisma.enums
 import prisma.errors
@@ -50,8 +51,8 @@ async def test_get_store_agents(mocker):
 
     # Mock prisma calls
     mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
-    mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
-    mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)
+    mock_store_agent.return_value.find_many = AsyncMock(return_value=mock_agents)
+    mock_store_agent.return_value.count = AsyncMock(return_value=1)
 
     # Call function
     result = await db.get_store_agents()
@@ -94,7 +95,7 @@ async def test_get_store_agent_details(mocker):
 
     # Mock StoreAgent prisma call
     mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
-    mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
+    mock_store_agent.return_value.find_first = AsyncMock(return_value=mock_agent)
 
     # Call function
     result = await db.get_store_agent_details("creator", "test-agent")
@@ -133,7 +134,7 @@ async def test_get_store_creator(mocker):
 
     # Mock prisma call
     mock_creator = mocker.patch("prisma.models.Creator.prisma")
-    mock_creator.return_value.find_unique = mocker.AsyncMock()
+    mock_creator.return_value.find_unique = AsyncMock()
     # Configure the mock to return values that will pass validation
     mock_creator.return_value.find_unique.return_value = mock_creator_data
 
@@ -189,7 +190,7 @@ async def test_create_store_submission(mocker):
         notifyOnAgentApproved=True,
         notifyOnAgentRejected=True,
         timezone="Europe/Delft",
-        subscriptionTier=prisma.enums.SubscriptionTier.FREE,  # type: ignore[reportCallIssue,reportAttributeAccessIssue]
+        subscriptionTier=prisma.enums.SubscriptionTier.BASIC,  # type: ignore[reportCallIssue,reportAttributeAccessIssue]
     )
     mock_agent = prisma.models.AgentGraph(
         id="agent-id",
@@ -236,23 +237,23 @@ async def test_create_store_submission(mocker):
 
     # Mock prisma calls
     mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
-    mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
+    mock_agent_graph.return_value.find_first = AsyncMock(return_value=mock_agent)
 
     # Mock transaction context manager
     mock_tx = mocker.MagicMock()
     mocker.patch(
         "backend.api.features.store.db.transaction",
-        return_value=mocker.AsyncMock(
-            __aenter__=mocker.AsyncMock(return_value=mock_tx),
-            __aexit__=mocker.AsyncMock(return_value=False),
+        return_value=AsyncMock(
+            __aenter__=AsyncMock(return_value=mock_tx),
+            __aexit__=AsyncMock(return_value=False),
         ),
     )
 
     mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
-    mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
+    mock_sl.return_value.find_unique = AsyncMock(return_value=None)
 
     mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
-    mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
+    mock_slv.return_value.create = AsyncMock(return_value=mock_version)
 
     # Call function
     result = await db.create_store_submission(
@@ -292,10 +293,8 @@ async def test_update_profile(mocker):
 
     # Mock prisma calls
     mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
-    mock_profile_db.return_value.find_first = mocker.AsyncMock(
-        return_value=mock_profile
-    )
-    mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)
+    mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
+    mock_profile_db.return_value.update = AsyncMock(return_value=mock_profile)
 
     # Test data
     profile = Profile(
@@ -336,9 +335,7 @@ async def test_get_user_profile(mocker):
 
     # Mock prisma calls
     mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
-    mock_profile_db.return_value.find_first = mocker.AsyncMock(
-        return_value=mock_profile
-    )
+    mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
 
     # Call function
     result = await db.get_user_profile("user-id")
@@ -396,3 +393,38 @@ async def test_get_store_agents_search_category_array_injection():
     # Verify the query executed without error
     # Category should be parameterized, preventing SQL injection
     assert isinstance(result.agents, list)
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_store_creators_only_returns_approved(mocker):
+    mock_creators = [
+        prisma.models.Creator(
+            name="Creator One",
+            username="creator1",
+            description="desc",
+            links=["link1"],
+            avatar_url="avatar.jpg",
+            num_agents=1,
+            agent_rating=4.5,
+            agent_runs=10,
+            top_categories=["test"],
+            is_featured=False,
+        )
+    ]
+
+    mock_creator = mocker.patch("prisma.models.Creator.prisma")
+    mock_creator.return_value.find_many = AsyncMock(return_value=mock_creators)
+    mock_creator.return_value.count = AsyncMock(return_value=1)
+
+    result = await db.get_store_creators()
+
+    assert len(result.creators) == 1
+    assert result.creators[0].username == "creator1"
+
+    mock_creator.return_value.find_many.assert_called_once()
+    mock_creator.return_value.count.assert_called_once()
+
+    _, find_kwargs = mock_creator.return_value.find_many.call_args
+    _, count_kwargs = mock_creator.return_value.count.call_args
+    assert find_kwargs["where"]["num_agents"] == {"gt": 0}
+    assert count_kwargs["where"]["num_agents"] == {"gt": 0}
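The recurring change in this test diff swaps `mocker.AsyncMock` for the stdlib `unittest.mock.AsyncMock`; the stdlib class (available since Python 3.8) does not depend on the pytest-mock version in use. A self-contained sketch of the pattern, with `find_many` standing in for an async Prisma call (names hypothetical):

```python
import asyncio
from unittest.mock import AsyncMock

# Stand-in for an async ORM call; awaiting an AsyncMock yields return_value.
find_many = AsyncMock(return_value=[{"username": "creator1"}])


async def list_creators():
    return await find_many(where={"num_agents": {"gt": 0}})


rows = asyncio.run(list_creators())
# AsyncMock records awaits, so await-specific assertions work too.
find_many.assert_awaited_once_with(where={"num_agents": {"gt": 0}})
```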
(File diff suppressed because it is too large)
@@ -5,7 +5,8 @@ import time
 import uuid
 from collections import defaultdict
 from datetime import datetime, timezone
-from typing import Annotated, Any, Literal, Sequence, get_args
+from typing import Annotated, Any, Literal, Sequence, cast, get_args
+from urllib.parse import urlparse
 
 import pydantic
 import stripe
@@ -25,10 +26,11 @@ from fastapi import (
 )
 from fastapi.concurrency import run_in_threadpool
 from prisma.enums import SubscriptionTier
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
 from typing_extensions import Optional, TypedDict
 
+from backend.api.features.workspace.routes import create_file_download_response
 from backend.api.model import (
     CreateAPIKeyRequest,
     CreateAPIKeyResponse,
@@ -42,23 +44,33 @@ from backend.api.model import (
     UploadFileResponse,
 )
 from backend.blocks import get_block, get_blocks
+from backend.copilot.rate_limit import get_tier_multipliers
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data.auth import api_key as api_key_db
 from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.data.credit import (
     AutoTopUpConfig,
+    PendingChangeUnknown,
     RefundRequest,
     TransactionHistory,
     UserCredit,
     cancel_stripe_subscription,
     create_subscription_checkout,
+    get_active_subscription_period_end,
     get_auto_top_up,
+    get_pending_subscription_change,
+    get_proration_credit_cents,
     get_subscription_price_id,
     get_user_credit_model,
+    handle_subscription_payment_failure,
+    handle_subscription_payment_success,
+    modify_stripe_subscription_for_tier,
+    release_pending_subscription_schedule,
     set_auto_top_up,
     set_subscription_tier,
     sync_subscription_from_stripe,
+    sync_subscription_schedule_from_stripe,
 )
 from backend.data.graph import GraphSettings
 from backend.data.model import CredentialsMetaInput, UserOnboarding
@@ -88,6 +100,7 @@ from backend.data.user import (
     update_user_notification_preference,
     update_user_timezone,
 )
+from backend.data.workspace import get_workspace_file_by_id
 from backend.executor import scheduler
 from backend.executor import utils as execution_utils
 from backend.integrations.webhooks.graph_lifecycle_hooks import (
@@ -689,19 +702,113 @@ async def get_user_auto_top_up(
 
 
 class SubscriptionTierRequest(BaseModel):
-    tier: Literal["FREE", "PRO", "BUSINESS"]
+    tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]
     success_url: str = ""
     cancel_url: str = ""
 
 
-class SubscriptionCheckoutResponse(BaseModel):
-    url: str
-
-
 class SubscriptionStatusResponse(BaseModel):
-    tier: str
-    monthly_cost: int
-    tier_costs: dict[str, int]
+    tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
+    monthly_cost: int  # amount in cents (Stripe convention)
+    tier_costs: dict[str, int]  # tier name -> amount in cents
+    tier_multipliers: dict[str, float] = Field(
+        default_factory=dict,
+        description=(
+            "Tier → rate-limit multiplier. Covers the same tiers listed in"
+            " ``tier_costs`` so the frontend can render rate-limit badges"
+            " relative to the lowest visible tier without knowing backend"
+            " defaults."
+        ),
+    )
+    proration_credit_cents: int  # unused portion of current sub to convert on upgrade
+    has_active_stripe_subscription: bool = Field(
+        default=False,
+        description=(
+            "True when the user has an active/trialing Stripe subscription. The"
+            " frontend uses this to branch upgrade UX: modify-in-place + saved-card"
+            " auto-charge when True, redirect to Stripe Checkout when False."
+        ),
+    )
+    current_period_end: Optional[int] = Field(
+        default=None,
+        description=(
+            "Unix timestamp of the active subscription's current_period_end. Used"
+            " to show the date Stripe will issue the next invoice (with prorated"
+            " upgrade charges, if any). None when no active sub."
+        ),
+    )
+    pending_tier: Optional[Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]] = None
+    pending_tier_effective_at: Optional[datetime] = None
+    url: str = Field(
+        default="",
+        description=(
+            "Populated only when POST /credits/subscription starts a Stripe Checkout"
+            " Session (BASIC → paid upgrade). Empty string in all other branches —"
+            " the client redirects to this URL when non-empty."
+        ),
+    )
+
+
+def _validate_checkout_redirect_url(url: str) -> bool:
+    """Return True if `url` matches the configured frontend origin.
+
+    Prevents open-redirect: attackers must not be able to supply arbitrary
+    success_url/cancel_url that Stripe will redirect users to after checkout.
+
+    Pre-parse rejection rules (applied before urlparse):
+    - Backslashes (``\\``) are normalised differently across parsers/browsers.
+    - Control characters (U+0000–U+001F) are not valid in URLs and may confuse
+      some URL-parsing implementations.
+    """
+    # Reject characters that can confuse URL parsers before any parsing.
+    if "\\" in url:
+        return False
+    if any(ord(c) < 0x20 for c in url):
+        return False
+
+    allowed = settings.config.frontend_base_url or settings.config.platform_base_url
+    if not allowed:
+        # No configured origin — refuse to validate rather than allow arbitrary URLs.
+        return False
+    try:
+        parsed = urlparse(url)
+        allowed_parsed = urlparse(allowed)
+    except ValueError:
+        return False
+    if parsed.scheme not in ("http", "https"):
+        return False
+    # Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
+    # can trick browsers into connecting to a different host than displayed.
+    # ``@`` in query/fragment is harmless and must be allowed.
+    if "@" in parsed.netloc:
+        return False
+    return (
+        parsed.scheme == allowed_parsed.scheme
+        and parsed.netloc == allowed_parsed.netloc
+    )
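The validator boils down to a scheme-plus-netloc equality check after rejecting parser-confusing input. A standalone sketch of that comparison, detached from the project's settings object (hypothetical `same_origin` helper, `app.example.com` as an assumed frontend origin):

```python
from urllib.parse import urlparse


def same_origin(url: str, allowed: str) -> bool:
    """Scheme + netloc equality check, as used to pin Stripe redirect URLs."""
    # Reject backslashes and control characters before parsing.
    if "\\" in url or any(ord(c) < 0x20 for c in url):
        return False
    parsed, allowed_parsed = urlparse(url), urlparse(allowed)
    # Only http(s); reject user:pass@host authority tricks.
    if parsed.scheme not in ("http", "https") or "@" in parsed.netloc:
        return False
    return (parsed.scheme, parsed.netloc) == (
        allowed_parsed.scheme,
        allowed_parsed.netloc,
    )
```

Comparing netloc (host plus port) rather than just hostname means `https://app.example.com:8443/...` is treated as a different origin, matching the browser same-origin model.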
+@cached(ttl_seconds=300, maxsize=32, cache_none=False)
+async def _get_stripe_price_amount(price_id: str) -> int | None:
+    """Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
+
+    Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
+    of caching the ``None`` sentinel so the next request retries Stripe instead
+    of being served a stale "no price" for the rest of the TTL window. Callers
+    should treat ``None`` as an unknown price and fall back to 0.
+
+    Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
+    every GET /credits/subscription page load and reduces quota consumption.
+    """
+    try:
+        price = await run_in_threadpool(stripe.Price.retrieve, price_id)
+        return price.unit_amount or 0
+    except stripe.StripeError:
+        logger.warning(
+            "Failed to retrieve Stripe price %s — returning None (not cached)",
+            price_id,
+        )
+        return None
+
+
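`@cached(..., cache_none=False)` is a project-local decorator; its key behaviour here is that `None` results are never stored, so a failed Stripe lookup retries on the next call instead of poisoning the cache for the TTL window. A minimal sketch of that behaviour (hypothetical `cached_skip_none` decorator, synchronous for brevity):

```python
import time


def cached_skip_none(ttl_seconds: float):
    """Tiny TTL cache that refuses to cache None, so failures retry next call."""
    def decorator(fn):
        store: dict = {}  # args -> (expires_at, value)

        def wrapper(*args):
            now = time.monotonic()
            hit = store.get(args)
            if hit and hit[0] > now:
                return hit[1]
            value = fn(*args)
            if value is not None:  # the cache_none=False behaviour
                store[args] = (now + ttl_seconds, value)
            return value

        return wrapper
    return decorator


calls = []


@cached_skip_none(ttl_seconds=300)
def get_price(price_id):
    calls.append(price_id)
    return None if price_id == "bad" else 1999


get_price("price_pro")
get_price("price_pro")  # served from cache, no second underlying call
get_price("bad")
get_price("bad")        # None was not cached, so the lookup runs again
```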
 @v1_router.get(
@@ -715,34 +822,89 @@ async def get_subscription_status(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> SubscriptionStatusResponse:
     user = await get_user_by_id(user_id)
-    tier = user.subscription_tier or SubscriptionTier.FREE
+    tier = user.subscription_tier or SubscriptionTier.NO_TIER
 
-    paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS]
+    # Tiers that *can* have a Stripe price configured (and therefore appear
+    # in the tier picker if the LD flag exposes a price-id). NO_TIER is not
+    # priceable — it's the implicit "no active subscription" state.
+    priceable_tiers = [
+        SubscriptionTier.BASIC,
+        SubscriptionTier.PRO,
+        SubscriptionTier.MAX,
+        SubscriptionTier.BUSINESS,
+    ]
     price_ids = await asyncio.gather(
-        *[get_subscription_price_id(t) for t in paid_tiers]
+        *[get_subscription_price_id(t) for t in priceable_tiers]
     )
 
-    tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
-    for t, price_id in zip(paid_tiers, price_ids):
-        cost = 0
-        if price_id:
-            try:
-                price = await run_in_threadpool(stripe.Price.retrieve, price_id)
-                cost = price.unit_amount or 0
-            except stripe.StripeError:
-                pass
-        tier_costs[t.value] = cost
-
-    return SubscriptionStatusResponse(
+    async def _cost(pid: str | None) -> int:
+        return (await _get_stripe_price_amount(pid) or 0) if pid else 0
+
+    costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
+
+    tier_costs: dict[str, int] = {}
+    for t, pid, cost in zip(priceable_tiers, price_ids, costs):
+        if pid:
+            tier_costs[t.value] = cost
+
+    # Expose the effective rate-limit multipliers alongside prices so the
+    # frontend can render "Nx rate limits" relative to the lowest visible
+    # tier without hard-coding backend defaults. Only emit entries for tiers
+    # that land in ``tier_costs`` — rows hidden at the price layer must stay
+    # hidden in the multiplier layer too.
+    multipliers = await get_tier_multipliers()
+    tier_multipliers: dict[str, float] = {
+        t.value: multipliers.get(t, 1.0)
+        for t in priceable_tiers
+        if t.value in tier_costs
+    }
+
+    current_monthly_cost = tier_costs.get(tier.value, 0)
+    proration_credit, current_period_end = await asyncio.gather(
+        get_proration_credit_cents(user_id, current_monthly_cost),
+        get_active_subscription_period_end(user_id),
+    )
+
+    try:
+        pending = await get_pending_subscription_change(user_id)
+    except (stripe.StripeError, PendingChangeUnknown):
+        # Swallow Stripe-side failures (rate limits, transient network) AND
+        # PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
+        # propagate past the cache so the next request retries fresh instead
+        # of serving a stale None for the TTL window. Let real bugs (KeyError,
+        # AttributeError, etc.) propagate so they surface in Sentry.
+        logger.exception(
+            "get_subscription_status: failed to resolve pending change for user %s",
+            user_id,
+        )
+        pending = None
+
+    response = SubscriptionStatusResponse(
         tier=tier.value,
-        monthly_cost=tier_costs.get(tier.value, 0),
+        monthly_cost=current_monthly_cost,
         tier_costs=tier_costs,
+        tier_multipliers=tier_multipliers,
+        proration_credit_cents=proration_credit,
+        has_active_stripe_subscription=current_period_end is not None,
+        current_period_end=current_period_end,
     )
+    if pending is not None:
+        pending_tier_enum, pending_effective_at = pending
+        if pending_tier_enum in (
+            SubscriptionTier.NO_TIER,
+            SubscriptionTier.BASIC,
+            SubscriptionTier.PRO,
+            SubscriptionTier.MAX,
+            SubscriptionTier.BUSINESS,
+        ):
+            response.pending_tier = pending_tier_enum.value
+            response.pending_tier_effective_at = pending_effective_at
+    return response
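The rewritten status handler fetches all tier prices concurrently and coalesces missing or failed lookups to 0 via `(await ... or 0) if pid else 0`. A self-contained sketch of that gather-with-fallback shape (all names and prices hypothetical):

```python
import asyncio


async def fetch_price(price_id):
    # Stand-in for the Stripe lookup; None models an unknown price or a
    # transient failure that the real code treats as "fall back to 0".
    prices = {"price_basic": 999, "price_pro": 1999}
    return prices.get(price_id)


async def tier_costs(price_ids):
    async def cost(pid):
        # None price_id -> tier has no configured price; None result -> 0.
        return (await fetch_price(pid) or 0) if pid else 0

    # gather preserves input order, so costs line up with price_ids.
    return await asyncio.gather(*[cost(pid) for pid in price_ids])


costs = asyncio.run(tier_costs(["price_basic", "price_pro", "price_unknown", None]))
```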
 @v1_router.post(
     path="/credits/subscription",
-    summary="Start a Stripe Checkout session to upgrade subscription tier",
+    summary="Update subscription tier or start a Stripe Checkout session",
     operation_id="updateSubscriptionTier",
     tags=["credits"],
     dependencies=[Security(requires_user)],
@@ -750,40 +912,172 @@ async def get_subscription_status(
 async def update_subscription_tier(
     request: SubscriptionTierRequest,
     user_id: Annotated[str, Security(get_user_id)],
-) -> SubscriptionCheckoutResponse:
-    # Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
+) -> SubscriptionStatusResponse:
+    # Pydantic validates tier is one of BASIC/PRO/MAX/BUSINESS via Literal type.
     tier = SubscriptionTier(request.tier)
 
     # ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
     user = await get_user_by_id(user_id)
-    if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
+    if (
+        user.subscription_tier or SubscriptionTier.NO_TIER
+    ) == SubscriptionTier.ENTERPRISE:
         raise HTTPException(
             status_code=403,
             detail="ENTERPRISE subscription changes must be managed by an administrator",
         )
 
+    # Same-tier request = "stay on my current tier" = cancel any pending
+    # scheduled change (paid→paid downgrade or paid→BASIC cancel). This is the
+    # collapsed behaviour that replaces the old /credits/subscription/cancel-pending
+    # route. Safe when no pending change exists: release_pending_subscription_schedule
+    # returns False and we simply return the current status.
+    if (user.subscription_tier or SubscriptionTier.NO_TIER) == tier:
+        try:
+            await release_pending_subscription_schedule(user_id)
+        except stripe.StripeError as e:
+            logger.exception(
+                "Stripe error releasing pending subscription change for user %s: %s",
+                user_id,
+                e,
+            )
+            raise HTTPException(
+                status_code=502,
+                detail=(
+                    "Unable to cancel the pending subscription change right now. "
+                    "Please try again or contact support."
+                ),
+            )
+        return await get_subscription_status(user_id)
+
     payment_enabled = await is_feature_enabled(
         Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
     )
 
     # Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
|
target_price_id = await get_subscription_price_id(tier)
|
||||||
if tier == SubscriptionTier.FREE:
|
|
||||||
|
# Cancel: target NO_TIER. Schedule Stripe cancellation at period end;
|
||||||
|
# cancel_at_period_end=True lets the webhook flip the DB tier. No active
|
||||||
|
# sub (admin-granted or never-paid) or payment disabled → DB flip.
|
||||||
|
# NO_TIER is never priceable, so this branch always fires for cancel
|
||||||
|
# requests regardless of LD config.
|
||||||
|
if tier == SubscriptionTier.NO_TIER:
|
||||||
if payment_enabled:
|
if payment_enabled:
|
||||||
await cancel_stripe_subscription(user_id)
|
try:
|
||||||
|
had_subscription = await cancel_stripe_subscription(user_id)
|
||||||
|
except stripe.StripeError as e:
|
||||||
|
logger.exception(
|
||||||
|
"Stripe error cancelling subscription for user %s: %s",
|
||||||
|
user_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=(
|
||||||
|
"Unable to cancel your subscription right now. "
|
||||||
|
"Please try again or contact support."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not had_subscription:
|
||||||
|
await set_subscription_tier(user_id, tier)
|
||||||
|
return await get_subscription_status(user_id)
|
||||||
await set_subscription_tier(user_id, tier)
|
await set_subscription_tier(user_id, tier)
|
||||||
return SubscriptionCheckoutResponse(url="")
|
return await get_subscription_status(user_id)
|
||||||
|
|
||||||
# Beta users (payment not enabled) → update tier directly without Stripe.
|
|
||||||
if not payment_enabled:
|
if not payment_enabled:
|
||||||
await set_subscription_tier(user_id, tier)
|
raise HTTPException(
|
||||||
return SubscriptionCheckoutResponse(url="")
|
status_code=422,
|
||||||
|
detail=f"Subscription not available for tier {tier.value}",
|
||||||
|
)
|
||||||
|
|
||||||
# Paid upgrade → create Stripe Checkout Session.
|
# Target has no LD price — not provisionable (matches the GET hiding).
|
||||||
|
if target_price_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=422,
|
||||||
|
detail=f"Subscription not available for tier {tier.value}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Modify in place if there's a sub; else fall through to Checkout below.
|
||||||
|
try:
|
||||||
|
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||||
|
if modified:
|
||||||
|
return await get_subscription_status(user_id)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=422, detail=str(e))
|
||||||
|
except stripe.InvalidRequestError as e:
|
||||||
|
# Stripe rejects schedule modify when phases mix currencies, e.g. the
|
||||||
|
# active sub was checked out in GBP but the target tier's Price is
|
||||||
|
# USD-only. 502 reads as outage; surface a 422 with a specific message
|
||||||
|
# so the user/admin can see what to fix in Stripe.
|
||||||
|
msg = str(e)
|
||||||
|
if "currency" in msg.lower():
|
||||||
|
logger.warning(
|
||||||
|
"Currency mismatch on tier change for user %s: %s", user_id, msg
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=422,
|
||||||
|
detail=(
|
||||||
|
"Tier change unavailable for your current billing currency."
|
||||||
|
" Please contact support — the target tier needs to be"
|
||||||
|
" configured for your currency in Stripe before this"
|
||||||
|
" change can go through."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.exception(
|
||||||
|
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=(
|
||||||
|
"Unable to update your subscription right now. "
|
||||||
|
"Please try again or contact support."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except stripe.StripeError as e:
|
||||||
|
logger.exception(
|
||||||
|
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=(
|
||||||
|
"Unable to update your subscription right now. "
|
||||||
|
"Please try again or contact support."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# No active Stripe subscription → create Stripe Checkout Session.
|
||||||
if not request.success_url or not request.cancel_url:
|
if not request.success_url or not request.cancel_url:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=422,
|
status_code=422,
|
||||||
detail="success_url and cancel_url are required for paid tier upgrades",
|
detail="success_url and cancel_url are required for paid tier upgrades",
|
||||||
)
|
)
|
||||||
|
# Open-redirect protection: both URLs must point to the configured frontend
|
||||||
|
# origin, otherwise an attacker could use our Stripe integration as a
|
||||||
|
# redirector to arbitrary phishing sites.
|
||||||
|
#
|
||||||
|
# Fail early with a clear 503 if the server is misconfigured (neither
|
||||||
|
# frontend_base_url nor platform_base_url set), so operators get an
|
||||||
|
# actionable error instead of the misleading "must match the platform
|
||||||
|
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
|
||||||
|
# produce when `allowed` is empty.
|
||||||
|
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
|
||||||
|
logger.error(
|
||||||
|
"update_subscription_tier: neither frontend_base_url nor "
|
||||||
|
"platform_base_url is configured; cannot validate checkout redirect URLs"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail=(
|
||||||
|
"Payment redirect URLs cannot be validated: "
|
||||||
|
"frontend_base_url or platform_base_url must be set on the server."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not _validate_checkout_redirect_url(
|
||||||
|
request.success_url
|
||||||
|
) or not _validate_checkout_redirect_url(request.cancel_url):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=422,
|
||||||
|
detail="success_url and cancel_url must match the platform frontend origin",
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
url = await create_subscription_checkout(
|
url = await create_subscription_checkout(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -791,54 +1085,116 @@ async def update_subscription_tier(
|
|||||||
success_url=request.success_url,
|
success_url=request.success_url,
|
||||||
cancel_url=request.cancel_url,
|
cancel_url=request.cancel_url,
|
||||||
)
|
)
|
||||||
except (ValueError, stripe.StripeError) as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=422, detail=str(e))
|
raise HTTPException(status_code=422, detail=str(e))
|
||||||
|
except stripe.StripeError as e:
|
||||||
|
logger.exception(
|
||||||
|
"Stripe error creating checkout session for user %s: %s", user_id, e
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=(
|
||||||
|
"Unable to start checkout right now. "
|
||||||
|
"Please try again or contact support."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
return SubscriptionCheckoutResponse(url=url)
|
status = await get_subscription_status(user_id)
|
||||||
|
status.url = url
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
@v1_router.post(
|
@v1_router.post(
|
||||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||||
)
|
)
|
||||||
async def stripe_webhook(request: Request):
|
async def stripe_webhook(request: Request):
|
||||||
|
webhook_secret = settings.secrets.stripe_webhook_secret
|
||||||
|
if not webhook_secret:
|
||||||
|
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
|
||||||
|
# signature over the same empty key). Reject all webhook calls when unconfigured.
|
||||||
|
logger.error(
|
||||||
|
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
|
||||||
|
"rejecting request to prevent signature bypass"
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=503, detail="Webhook not configured")
|
||||||
|
|
||||||
# Get the raw request body
|
# Get the raw request body
|
||||||
payload = await request.body()
|
payload = await request.body()
|
||||||
# Get the signature header
|
# Get the signature header
|
||||||
sig_header = request.headers.get("stripe-signature")
|
sig_header = request.headers.get("stripe-signature")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event = stripe.Webhook.construct_event(
|
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
||||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
except ValueError:
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
# Invalid payload
|
# Invalid payload
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
except stripe.SignatureVerificationError:
|
||||||
)
|
|
||||||
except stripe.SignatureVerificationError as e:
|
|
||||||
# Invalid signature
|
# Invalid signature
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
|
||||||
|
# Defensive payload extraction. A malformed payload (missing/non-dict
|
||||||
|
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
|
||||||
|
# AFTER signature verification — which Stripe interprets as a delivery
|
||||||
|
# failure and retries forever, while spamming Sentry with no useful info.
|
||||||
|
# Acknowledge with 200 and a warning so Stripe stops retrying.
|
||||||
|
event_type = event.get("type", "")
|
||||||
|
event_data = event.get("data") or {}
|
||||||
|
data_object = event_data.get("object") if isinstance(event_data, dict) else None
|
||||||
|
if not isinstance(data_object, dict):
|
||||||
|
logger.warning(
|
||||||
|
"stripe_webhook: %s missing or non-dict data.object; ignoring",
|
||||||
|
event_type,
|
||||||
)
|
)
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
if (
|
if event_type in (
|
||||||
event["type"] == "checkout.session.completed"
|
"checkout.session.completed",
|
||||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
"checkout.session.async_payment_succeeded",
|
||||||
):
|
):
|
||||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
session_id = data_object.get("id")
|
||||||
|
if not session_id:
|
||||||
|
logger.warning(
|
||||||
|
"stripe_webhook: %s missing data.object.id; ignoring", event_type
|
||||||
|
)
|
||||||
|
return Response(status_code=200)
|
||||||
|
await UserCredit().fulfill_checkout(session_id=session_id)
|
||||||
|
|
||||||
if event["type"] in (
|
if event_type in (
|
||||||
"customer.subscription.created",
|
"customer.subscription.created",
|
||||||
"customer.subscription.updated",
|
"customer.subscription.updated",
|
||||||
"customer.subscription.deleted",
|
"customer.subscription.deleted",
|
||||||
):
|
):
|
||||||
await sync_subscription_from_stripe(event["data"]["object"])
|
await sync_subscription_from_stripe(data_object)
|
||||||
|
|
||||||
if event["type"] == "charge.dispute.created":
|
# `subscription_schedule.updated` is deliberately omitted: our own
|
||||||
await UserCredit().handle_dispute(event["data"]["object"])
|
# `SubscriptionSchedule.create` + `.modify` calls in
|
||||||
|
# `_schedule_downgrade_at_period_end` would fire that event right back at us
|
||||||
|
# and loop redundant traffic through this handler. We only care about state
|
||||||
|
# transitions (released / completed); phase advance to the new price is
|
||||||
|
# already covered by `customer.subscription.updated`.
|
||||||
|
if event_type in (
|
||||||
|
"subscription_schedule.released",
|
||||||
|
"subscription_schedule.completed",
|
||||||
|
):
|
||||||
|
await sync_subscription_schedule_from_stripe(data_object)
|
||||||
|
|
||||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
if event_type == "invoice.payment_succeeded":
|
||||||
await UserCredit().deduct_credits(event["data"]["object"])
|
await handle_subscription_payment_success(data_object)
|
||||||
|
|
||||||
|
if event_type == "invoice.payment_failed":
|
||||||
|
await handle_subscription_payment_failure(data_object)
|
||||||
|
|
||||||
|
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||||
|
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||||
|
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||||
|
# to satisfy the type checker without changing runtime behaviour.
|
||||||
|
if event_type == "charge.dispute.created":
|
||||||
|
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
|
||||||
|
|
||||||
|
if event_type == "refund.created" or event_type == "charge.dispute.closed":
|
||||||
|
await UserCredit().deduct_credits(
|
||||||
|
cast("stripe.Refund | stripe.Dispute", data_object)
|
||||||
|
)
|
||||||
|
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
@@ -1422,6 +1778,10 @@ async def enable_execution_sharing(
|
|||||||
# Generate a unique share token
|
# Generate a unique share token
|
||||||
share_token = str(uuid.uuid4())
|
share_token = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Remove stale allowlist records before updating the token — prevents a
|
||||||
|
# window where old records + new token could coexist.
|
||||||
|
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||||
|
|
||||||
# Update the execution with share info
|
# Update the execution with share info
|
||||||
await execution_db.update_graph_execution_share_status(
|
await execution_db.update_graph_execution_share_status(
|
||||||
execution_id=graph_exec_id,
|
execution_id=graph_exec_id,
|
||||||
@@ -1431,6 +1791,14 @@ async def enable_execution_sharing(
|
|||||||
shared_at=datetime.now(timezone.utc),
|
shared_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create allowlist of workspace files referenced in outputs
|
||||||
|
await execution_db.create_shared_execution_files(
|
||||||
|
execution_id=graph_exec_id,
|
||||||
|
share_token=share_token,
|
||||||
|
user_id=user_id,
|
||||||
|
outputs=execution.outputs,
|
||||||
|
)
|
||||||
|
|
||||||
# Return the share URL
|
# Return the share URL
|
||||||
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
|
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||||
share_url = f"{frontend_url}/share/{share_token}"
|
share_url = f"{frontend_url}/share/{share_token}"
|
||||||
@@ -1456,6 +1824,9 @@ async def disable_execution_sharing(
|
|||||||
if not execution:
|
if not execution:
|
||||||
raise HTTPException(status_code=404, detail="Execution not found")
|
raise HTTPException(status_code=404, detail="Execution not found")
|
||||||
|
|
||||||
|
# Remove shared file allowlist records
|
||||||
|
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||||
|
|
||||||
# Remove share info
|
# Remove share info
|
||||||
await execution_db.update_graph_execution_share_status(
|
await execution_db.update_graph_execution_share_status(
|
||||||
execution_id=graph_exec_id,
|
execution_id=graph_exec_id,
|
||||||
@@ -1481,6 +1852,43 @@ async def get_shared_execution(
|
|||||||
return execution
|
return execution
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.get(
|
||||||
|
"/public/shared/{share_token}/files/{file_id}/download",
|
||||||
|
summary="Download a file from a shared execution",
|
||||||
|
operation_id="download_shared_file",
|
||||||
|
tags=["graphs"],
|
||||||
|
)
|
||||||
|
async def download_shared_file(
|
||||||
|
share_token: Annotated[
|
||||||
|
str,
|
||||||
|
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||||
|
],
|
||||||
|
file_id: Annotated[
|
||||||
|
str,
|
||||||
|
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||||
|
],
|
||||||
|
) -> Response:
|
||||||
|
"""Download a workspace file from a shared execution (no auth required).
|
||||||
|
|
||||||
|
Validates that the file was explicitly exposed when sharing was enabled.
|
||||||
|
Returns a uniform 404 for all failure modes to prevent enumeration attacks.
|
||||||
|
"""
|
||||||
|
# Single-query validation against the allowlist
|
||||||
|
execution_id = await execution_db.get_shared_execution_file(
|
||||||
|
share_token=share_token, file_id=file_id
|
||||||
|
)
|
||||||
|
if not execution_id:
|
||||||
|
raise HTTPException(status_code=404, detail="Not found")
|
||||||
|
|
||||||
|
# Look up the actual file (no workspace scoping needed — the allowlist
|
||||||
|
# already validated that this file belongs to the shared execution)
|
||||||
|
file = await get_workspace_file_by_id(file_id)
|
||||||
|
if not file:
|
||||||
|
raise HTTPException(status_code=404, detail="Not found")
|
||||||
|
|
||||||
|
return await create_file_download_response(file, inline=True)
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
##################### Schedules ########################
|
##################### Schedules ########################
|
||||||
########################################################
|
########################################################
|
||||||
|
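The empty-secret guard added to `stripe_webhook` above exists because Stripe's `v1` signature is simply an HMAC-SHA256 of `"{timestamp}.{raw_body}"` under the endpoint secret, so with an empty key anyone can compute a "valid" signature. A minimal standalone sketch of that scheme for illustration (`verify_stripe_signature` is a hypothetical helper, not part of this diff; the real handler uses `stripe.Webhook.construct_event`, which additionally enforces a timestamp tolerance that is omitted here):

```python
import hashlib
import hmac


def verify_stripe_signature(payload: bytes, sig_header: str, secret: str) -> bool:
    """Recompute Stripe's v1 webhook signature: HMAC-SHA256 over "{t}.{payload}"."""
    if not secret:
        # An empty key still yields a computable HMAC, so anyone could forge a
        # "valid" signature — this is the bypass the 503 guard above prevents.
        return False
    # Header shape: "t=<unix_ts>,v1=<hex_hmac>" (possibly more v1 entries in reality).
    parts = dict(p.split("=", 1) for p in sig_header.split(","))
    signed_payload = f"{parts['t']}.".encode() + payload
    expected = hmac.new(secret.encode(), signed_payload, hashlib.sha256).hexdigest()
    # Constant-time compare to avoid timing side channels.
    return hmac.compare_digest(expected, parts["v1"])
```

Tampering with the body or presenting an empty secret both fail verification, which is exactly the property the new guard relies on.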
autogpt_platform/backend/backend/api/features/v1_share_test.py  (new file, +157 lines)
@@ -0,0 +1,157 @@
+"""Tests for the public shared file download endpoint."""
+
+from datetime import datetime, timezone
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from starlette.responses import Response
+
+from backend.api.features.v1 import v1_router
+from backend.data.workspace import WorkspaceFile
+
+app = FastAPI()
+app.include_router(v1_router, prefix="/api")
+
+VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
+VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
+
+
+def _make_workspace_file(**overrides) -> WorkspaceFile:
+    defaults = {
+        "id": VALID_FILE_ID,
+        "workspace_id": "ws-001",
+        "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
+        "updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
+        "name": "image.png",
+        "path": "/image.png",
+        "storage_path": "local://uploads/image.png",
+        "mime_type": "image/png",
+        "size_bytes": 4,
+        "checksum": None,
+        "is_deleted": False,
+        "deleted_at": None,
+        "metadata": {},
+    }
+    defaults.update(overrides)
+    return WorkspaceFile(**defaults)
+
+
+def _mock_download_response(**kwargs):
+    """Return an AsyncMock that resolves to a Response with inline disposition."""
+
+    async def _handler(file, *, inline=False):
+        return Response(
+            content=b"\x89PNG",
+            media_type="image/png",
+            headers={
+                "Content-Disposition": (
+                    'inline; filename="image.png"'
+                    if inline
+                    else 'attachment; filename="image.png"'
+                ),
+                "Content-Length": "4",
+            },
+        )
+
+    return _handler
+
+
+class TestDownloadSharedFile:
+    """Tests for GET /api/public/shared/{token}/files/{id}/download."""
+
+    @pytest.fixture(autouse=True)
+    def _client(self):
+        self.client = TestClient(app, raise_server_exceptions=False)
+
+    def test_valid_token_and_file_returns_inline_content(self):
+        with (
+            patch(
+                "backend.api.features.v1.execution_db.get_shared_execution_file",
+                new_callable=AsyncMock,
+                return_value="exec-123",
+            ),
+            patch(
+                "backend.api.features.v1.get_workspace_file_by_id",
+                new_callable=AsyncMock,
+                return_value=_make_workspace_file(),
+            ),
+            patch(
+                "backend.api.features.v1.create_file_download_response",
+                side_effect=_mock_download_response(),
+            ),
+        ):
+            response = self.client.get(
+                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
+            )
+
+        assert response.status_code == 200
+        assert response.content == b"\x89PNG"
+        assert "inline" in response.headers["Content-Disposition"]
+
+    def test_invalid_token_format_returns_422(self):
+        response = self.client.get(
+            f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
+        )
+        assert response.status_code == 422
+
+    def test_token_not_in_allowlist_returns_404(self):
+        with patch(
+            "backend.api.features.v1.execution_db.get_shared_execution_file",
+            new_callable=AsyncMock,
+            return_value=None,
+        ):
+            response = self.client.get(
+                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
+            )
+        assert response.status_code == 404
+
+    def test_file_missing_from_workspace_returns_404(self):
+        with (
+            patch(
+                "backend.api.features.v1.execution_db.get_shared_execution_file",
+                new_callable=AsyncMock,
+                return_value="exec-123",
+            ),
+            patch(
+                "backend.api.features.v1.get_workspace_file_by_id",
+                new_callable=AsyncMock,
+                return_value=None,
+            ),
+        ):
+            response = self.client.get(
+                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
+            )
+        assert response.status_code == 404
+
+    def test_uniform_404_prevents_enumeration(self):
+        """Both failure modes produce identical 404 — no information leak."""
+        with patch(
+            "backend.api.features.v1.execution_db.get_shared_execution_file",
+            new_callable=AsyncMock,
+            return_value=None,
+        ):
+            resp_no_allow = self.client.get(
+                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
+            )
+
+        with (
+            patch(
+                "backend.api.features.v1.execution_db.get_shared_execution_file",
+                new_callable=AsyncMock,
+                return_value="exec-123",
+            ),
+            patch(
+                "backend.api.features.v1.get_workspace_file_by_id",
+                new_callable=AsyncMock,
+                return_value=None,
+            ),
+        ):
+            resp_no_file = self.client.get(
+                f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
+            )
+
+        assert resp_no_allow.status_code == 404
+        assert resp_no_file.status_code == 404
+        assert resp_no_allow.json() == resp_no_file.json()
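The anti-enumeration property these tests pin down (an allowlist miss and a missing file must be indistinguishable to the caller) can be sketched framework-free. All names here are illustrative stand-ins, not the repo's API:

```python
class NotFound(Exception):
    """Uniform failure signal: same status and detail for every miss."""

    status_code = 404
    detail = "Not found"


# Illustrative in-memory stand-ins for the allowlist table and the file store.
ALLOWLIST = {("token-a", "file-1"): "exec-123"}
FILES = {"file-1": b"\x89PNG"}


def download_shared_file(share_token: str, file_id: str) -> bytes:
    execution_id = ALLOWLIST.get((share_token, file_id))
    if execution_id is None:
        raise NotFound()  # unknown token/file pair
    content = FILES.get(file_id)
    if content is None:
        raise NotFound()  # allowlisted but file gone: identical response
    return content
```

Because both failure paths raise the same exception with the same status and detail, a probing client learns nothing about whether a token exists, whether a file ID is valid, or which check failed.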
@@ -29,7 +29,9 @@ from backend.util.workspace import WorkspaceManager
|
|||||||
from backend.util.workspace_storage import get_workspace_storage
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_filename_for_header(filename: str) -> str:
|
def _sanitize_filename_for_header(
|
||||||
|
filename: str, disposition: str = "attachment"
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Sanitize filename for Content-Disposition header to prevent header injection.
|
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||||
|
|
||||||
@@ -44,11 +46,11 @@ def _sanitize_filename_for_header(filename: str) -> str:
|
|||||||
# Check if filename has non-ASCII characters
|
# Check if filename has non-ASCII characters
|
||||||
try:
|
try:
|
||||||
sanitized.encode("ascii")
|
sanitized.encode("ascii")
|
||||||
return f'attachment; filename="{sanitized}"'
|
return f'{disposition}; filename="{sanitized}"'
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
# Use RFC5987 encoding for UTF-8 filenames
|
# Use RFC5987 encoding for UTF-8 filenames
|
||||||
encoded = quote(sanitized, safe="")
|
encoded = quote(sanitized, safe="")
|
||||||
return f"attachment; filename*=UTF-8''{encoded}"
|
return f"{disposition}; filename*=UTF-8''{encoded}"
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -58,19 +60,26 @@ router = fastapi.APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
|
def _create_streaming_response(
|
||||||
|
content: bytes, file: WorkspaceFile, *, inline: bool = False
|
||||||
|
) -> Response:
|
||||||
"""Create a streaming response for file content."""
|
"""Create a streaming response for file content."""
|
||||||
|
disposition = _sanitize_filename_for_header(
|
||||||
|
file.name, disposition="inline" if inline else "attachment"
|
||||||
|
)
|
||||||
return Response(
|
return Response(
|
||||||
content=content,
|
content=content,
|
||||||
media_type=file.mime_type,
|
media_type=file.mime_type,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
"Content-Disposition": disposition,
|
||||||
"Content-Length": str(len(content)),
|
"Content-Length": str(len(content)),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
async def create_file_download_response(
|
||||||
|
file: WorkspaceFile, *, inline: bool = False
|
||||||
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Create a download response for a workspace file.
|
Create a download response for a workspace file.
|
||||||
|
|
||||||
@@ -82,7 +91,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
|||||||
# For local storage, stream the file directly
|
# For local storage, stream the file directly
|
||||||
if file.storage_path.startswith("local://"):
|
if file.storage_path.startswith("local://"):
|
||||||
content = await storage.retrieve(file.storage_path)
|
content = await storage.retrieve(file.storage_path)
|
||||||
return _create_streaming_response(content, file)
|
return _create_streaming_response(content, file, inline=inline)
|
||||||
|
|
||||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||||
try:
|
try:
|
||||||
@@ -90,7 +99,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
|||||||
# If we got back an API path (fallback), stream directly instead
|
# If we got back an API path (fallback), stream directly instead
|
||||||
if url.startswith("/api/"):
|
if url.startswith("/api/"):
|
||||||
content = await storage.retrieve(file.storage_path)
|
content = await storage.retrieve(file.storage_path)
|
||||||
return _create_streaming_response(content, file)
|
return _create_streaming_response(content, file, inline=inline)
|
||||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the signed URL failure with context
|
# Log the signed URL failure with context
|
||||||
@@ -102,7 +111,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
|||||||
# Fall back to streaming directly from GCS
|
# Fall back to streaming directly from GCS
|
||||||
try:
|
try:
|
||||||
content = await storage.retrieve(file.storage_path)
|
content = await storage.retrieve(file.storage_path)
|
||||||
return _create_streaming_response(content, file)
|
return _create_streaming_response(content, file, inline=inline)
|
||||||
except Exception as fallback_error:
|
except Exception as fallback_error:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Fallback streaming also failed for file {file.id} "
|
f"Fallback streaming also failed for file {file.id} "
|
||||||
@@ -169,7 +178,7 @@ async def download_file(
|
|||||||
if file is None:
|
if file is None:
|
||||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
return await _create_file_download_response(file)
|
return await create_file_download_response(file)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
|
|||||||
@@ -600,3 +600,221 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
     mock_instance.list_files.assert_called_once_with(
         limit=11, offset=50, include_all_sessions=True
     )
+
+
+# -- _sanitize_filename_for_header tests --
+
+
+class TestSanitizeFilenameForHeader:
+    def test_simple_ascii_attachment(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        assert _sanitize_filename_for_header("report.pdf") == (
+            'attachment; filename="report.pdf"'
+        )
+
+    def test_inline_disposition(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        assert _sanitize_filename_for_header("image.png", disposition="inline") == (
+            'inline; filename="image.png"'
+        )
+
+    def test_strips_cr_lf_null(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
+        assert "\r" not in result
+        assert "\n" not in result
+        assert "\x00" not in result
+        assert 'filename="abcd.txt"' in result
+
+    def test_escapes_quotes(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        result = _sanitize_filename_for_header('file"name.txt')
+        assert 'filename="file\\"name.txt"' in result
+
+    def test_header_injection_blocked(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
+        # CR/LF stripped — the remaining text is safely inside the quoted value
+        assert "\r" not in result
+        assert "\n" not in result
+        assert result == 'attachment; filename="evil.txtX-Injected: true"'
+
+    def test_unicode_uses_rfc5987(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        result = _sanitize_filename_for_header("日本語.pdf")
+        assert "filename*=UTF-8''" in result
+        assert "attachment" in result
+
+    def test_unicode_inline(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        result = _sanitize_filename_for_header("图片.png", disposition="inline")
+        assert result.startswith("inline; filename*=UTF-8''")
+
+    def test_empty_filename(self):
+        from backend.api.features.workspace.routes import _sanitize_filename_for_header
+
+        result = _sanitize_filename_for_header("")
+        assert result == 'attachment; filename=""'
+
+
+# -- _create_streaming_response tests --
+
+
+class TestCreateStreamingResponse:
+    def test_attachment_disposition_by_default(self):
+        from backend.api.features.workspace.routes import _create_streaming_response
+
+        file = _make_file(name="data.bin", mime_type="application/octet-stream")
+        response = _create_streaming_response(b"binary-data", file)
+        assert (
+            response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
+        )
+        assert response.headers["Content-Type"] == "application/octet-stream"
+        assert response.headers["Content-Length"] == "11"
+        assert response.body == b"binary-data"
+
+    def test_inline_disposition(self):
+        from backend.api.features.workspace.routes import _create_streaming_response
+
+        file = _make_file(name="photo.png", mime_type="image/png")
+        response = _create_streaming_response(b"\x89PNG", file, inline=True)
+        assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
+        assert response.headers["Content-Type"] == "image/png"
+
+    def test_inline_sanitizes_filename(self):
+        from backend.api.features.workspace.routes import _create_streaming_response
+
+        file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
+        response = _create_streaming_response(b"data", file, inline=True)
+        assert "\r" not in response.headers["Content-Disposition"]
+        assert "\n" not in response.headers["Content-Disposition"]
+        assert "inline" in response.headers["Content-Disposition"]
+
+    def test_content_length_matches_body(self):
+        from backend.api.features.workspace.routes import _create_streaming_response
+
+        content = b"x" * 1000
+        file = _make_file(name="big.bin", mime_type="application/octet-stream")
+        response = _create_streaming_response(content, file)
+        assert response.headers["Content-Length"] == "1000"
+
+
+# -- create_file_download_response tests --
+
+
+class TestCreateFileDownloadResponse:
+    @pytest.mark.asyncio
+    async def test_local_storage_returns_streaming_response(self, mocker):
+        from backend.api.features.workspace.routes import create_file_download_response
+
+        mock_storage = AsyncMock()
+        mock_storage.retrieve.return_value = b"file contents"
+        mocker.patch(
+            "backend.api.features.workspace.routes.get_workspace_storage",
+            return_value=mock_storage,
+        )
+
+        file = _make_file(
+            storage_path="local://uploads/test.txt",
+            mime_type="text/plain",
+        )
+        response = await create_file_download_response(file)
+        assert response.status_code == 200
+        assert response.body == b"file contents"
+        assert "attachment" in response.headers["Content-Disposition"]
+
+    @pytest.mark.asyncio
+    async def test_local_storage_inline(self, mocker):
+        from backend.api.features.workspace.routes import create_file_download_response
+
+        mock_storage = AsyncMock()
+        mock_storage.retrieve.return_value = b"\x89PNG"
+        mocker.patch(
+            "backend.api.features.workspace.routes.get_workspace_storage",
+            return_value=mock_storage,
+        )
+
+        file = _make_file(
+            storage_path="local://uploads/photo.png",
+            mime_type="image/png",
+            name="photo.png",
+        )
+        response = await create_file_download_response(file, inline=True)
+        assert "inline" in response.headers["Content-Disposition"]
+
+    @pytest.mark.asyncio
+    async def test_gcs_redirect(self, mocker):
+        from backend.api.features.workspace.routes import create_file_download_response
+
+        mock_storage = AsyncMock()
+        mock_storage.get_download_url.return_value = (
+            "https://storage.googleapis.com/signed-url"
+        )
+        mocker.patch(
+            "backend.api.features.workspace.routes.get_workspace_storage",
+            return_value=mock_storage,
+        )
+
+        file = _make_file(storage_path="gcs://bucket/file.pdf")
+        response = await create_file_download_response(file)
+        assert response.status_code == 302
+        assert (
+            response.headers["location"] == "https://storage.googleapis.com/signed-url"
+        )
+
+    @pytest.mark.asyncio
+    async def test_gcs_api_fallback_streams_directly(self, mocker):
+        from backend.api.features.workspace.routes import create_file_download_response
+
+        mock_storage = AsyncMock()
+        mock_storage.get_download_url.return_value = "/api/fallback"
+        mock_storage.retrieve.return_value = b"fallback content"
+        mocker.patch(
+            "backend.api.features.workspace.routes.get_workspace_storage",
+            return_value=mock_storage,
+        )
+
+        file = _make_file(storage_path="gcs://bucket/file.txt")
+        response = await create_file_download_response(file)
+        assert response.status_code == 200
+        assert response.body == b"fallback content"
+
+    @pytest.mark.asyncio
+    async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
+        from backend.api.features.workspace.routes import create_file_download_response
+
+        mock_storage = AsyncMock()
+        mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
+        mock_storage.retrieve.return_value = b"streamed"
+        mocker.patch(
+            "backend.api.features.workspace.routes.get_workspace_storage",
+            return_value=mock_storage,
+        )
+
+        file = _make_file(storage_path="gcs://bucket/file.txt")
+        response = await create_file_download_response(file)
+        assert response.status_code == 200
+        assert response.body == b"streamed"
+
+    @pytest.mark.asyncio
+    async def test_gcs_total_failure_raises(self, mocker):
+        from backend.api.features.workspace.routes import create_file_download_response
+
+        mock_storage = AsyncMock()
+        mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
+        mock_storage.retrieve.side_effect = RuntimeError("Also failed")
+        mocker.patch(
+            "backend.api.features.workspace.routes.get_workspace_storage",
+            return_value=mock_storage,
+        )
+
+        file = _make_file(storage_path="gcs://bucket/file.txt")
+        with pytest.raises(RuntimeError, match="Also failed"):
+            await create_file_download_response(file)
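As a side note for reviewers, the contract these sanitization tests pin down can be sketched standalone. The function below is not the repository's `_sanitize_filename_for_header` implementation, only a minimal stdlib approximation of the same RFC 6266 quoted-string / RFC 5987 extended-parameter behavior:

```python
from urllib.parse import quote


def sanitize_filename_for_header(filename: str, disposition: str = "attachment") -> str:
    # Strip CR, LF, and NUL so the value can never break out of the header line.
    cleaned = filename.replace("\r", "").replace("\n", "").replace("\x00", "")
    if cleaned.isascii():
        # ASCII names fit the plain quoted-string form; escape embedded quotes.
        escaped = cleaned.replace('"', '\\"')
        return f'{disposition}; filename="{escaped}"'
    # Non-ASCII names use the RFC 5987 extended parameter (percent-encoded UTF-8).
    return f"{disposition}; filename*=UTF-8''{quote(cleaned)}"


print(sanitize_filename_for_header("report.pdf"))
# attachment; filename="report.pdf"
print(sanitize_filename_for_header("evil.txt\r\nX-Injected: true"))
# attachment; filename="evil.txtX-Injected: true"
```

Browsers that understand RFC 5987 prefer `filename*`; a production version would usually also emit a plain ASCII `filename` fallback alongside it.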
@@ -17,6 +17,7 @@ from fastapi.routing import APIRoute
 from prisma.errors import PrismaError

 import backend.api.features.admin.credit_admin_routes
+import backend.api.features.admin.diagnostics_admin_routes
 import backend.api.features.admin.execution_analytics_routes
 import backend.api.features.admin.platform_cost_routes
 import backend.api.features.admin.rate_limit_admin_routes
@@ -31,7 +32,9 @@ import backend.api.features.library.routes
 import backend.api.features.mcp.routes as mcp_routes
 import backend.api.features.oauth
 import backend.api.features.otto.routes
+import backend.api.features.platform_linking.routes
 import backend.api.features.postmark.postmark
+import backend.api.features.push.routes as push_routes
 import backend.api.features.store.model
 import backend.api.features.store.routes
 import backend.api.features.v1
@@ -39,6 +42,7 @@ import backend.api.features.workspace.routes as workspace_routes
 import backend.data.block
 import backend.data.db
 import backend.data.graph
+import backend.data.redis_client
 import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
@@ -93,6 +97,8 @@ async def lifespan_context(app: fastapi.FastAPI):
     verify_auth_settings()

     await backend.data.db.connect()
+    # Eager connect to fail-fast if Redis is unreachable.
+    await backend.data.redis_client.get_redis_async()

     # Configure thread pool for FastAPI sync operation performance
     # CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
@@ -144,7 +150,18 @@ async def lifespan_context(app: fastapi.FastAPI):
     except Exception as e:
         logger.warning(f"Error shutting down workspace storage: {e}")

-    await backend.data.db.disconnect()
+    # Each cleanup is wrapped so one failure doesn't block the rest. The
+    # Redis close in particular silences asyncio's "Unclosed ClusterNode"
+    # GC warning at interpreter shutdown.
+    try:
+        await backend.data.redis_client.disconnect_async()
+    except Exception:
+        logger.warning("redis_client.disconnect_async failed", exc_info=True)
+
+    try:
+        await backend.data.db.disconnect()
+    except Exception:
+        logger.warning("db.disconnect failed", exc_info=True)


 def custom_generate_unique_id(route: APIRoute):
@@ -320,6 +337,11 @@ app.include_router(
     tags=["v2", "admin"],
     prefix="/api/credits",
 )
+app.include_router(
+    backend.api.features.admin.diagnostics_admin_routes.router,
+    tags=["v2", "admin"],
+    prefix="/api",
+)
 app.include_router(
     backend.api.features.admin.execution_analytics_routes.router,
     tags=["v2", "admin"],
@@ -372,6 +394,16 @@ app.include_router(
     tags=["oauth"],
     prefix="/api/oauth",
 )
+app.include_router(
+    push_routes.router,
+    tags=["push"],
+    prefix="/api/push",
+)
+app.include_router(
+    backend.api.features.platform_linking.routes.router,
+    tags=["platform-linking"],
+    prefix="/api/platform-linking",
+)

 app.mount("/external-api", external_api)

@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from contextlib import asynccontextmanager
 from typing import Protocol
@@ -17,14 +16,12 @@ from backend.api.model import (
     WSSubscribeGraphExecutionsRequest,
 )
 from backend.api.utils.cors import build_cors_params
-from backend.data.execution import AsyncRedisExecutionEventBus
-from backend.data.notification_bus import AsyncRedisNotificationEventBus
+from backend.data import db, redis_client
 from backend.data.user import DEFAULT_USER_ID
 from backend.monitoring.instrumentation import (
     instrument_fastapi,
     update_websocket_connections,
 )
-from backend.util.retry import continuous_retry
 from backend.util.service import AppProcess
 from backend.util.settings import AppEnvironment, Config, Settings

@@ -34,10 +31,24 @@ settings = Settings()

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    manager = get_connection_manager()
-    fut = asyncio.create_task(event_broadcaster(manager))
-    fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
-    yield
+    # Prisma is needed to resolve graph_id from graph_exec_id on subscribe.
+    await db.connect()
+    # Eager connect to fail-fast if Redis is unreachable.
+    await redis_client.get_redis_async()
+    try:
+        yield
+    finally:
+        # Each cleanup is wrapped so one failure doesn't block the rest. The
+        # Redis close silences asyncio's "Unclosed ClusterNode" GC warning at
+        # interpreter shutdown.
+        try:
+            await redis_client.disconnect_async()
+        except Exception:
+            logger.warning("redis_client.disconnect_async failed", exc_info=True)
+        try:
+            await db.disconnect()
+        except Exception:
+            logger.warning("db.disconnect failed", exc_info=True)


 docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
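The shutdown pattern this lifespan adopts, wrapping each cleanup so one failure cannot block the rest, can be demonstrated in isolation. The `redis_disconnect`/`db_disconnect` coroutines below are hypothetical stand-ins for the project's clients; only the try/finally shape matters:

```python
import asyncio
import logging
from contextlib import asynccontextmanager

logger = logging.getLogger("lifespan-demo")
events: list[str] = []


async def redis_disconnect():
    # Simulate a cleanup that blows up at shutdown.
    raise RuntimeError("redis already gone")


async def db_disconnect():
    events.append("db closed")


@asynccontextmanager
async def lifespan(app):
    # Startup work (db.connect(), eager Redis ping, ...) would go here.
    try:
        yield
    finally:
        # Wrap each cleanup so one failure cannot block the rest.
        try:
            await redis_disconnect()
        except Exception:
            logger.warning("redis disconnect failed", exc_info=True)
        try:
            await db_disconnect()
        except Exception:
            logger.warning("db disconnect failed", exc_info=True)


async def main():
    async with lifespan(app=None):
        events.append("app running")


asyncio.run(main())
print(events)  # ['app running', 'db closed']
```

The failing Redis close is logged and swallowed, and the database still disconnects, which is exactly the fault isolation the diff is after.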
@@ -61,31 +72,6 @@ def get_connection_manager():
     return _connection_manager


-@continuous_retry()
-async def event_broadcaster(manager: ConnectionManager):
-    execution_bus = AsyncRedisExecutionEventBus()
-    notification_bus = AsyncRedisNotificationEventBus()
-
-    try:
-
-        async def execution_worker():
-            async for event in execution_bus.listen("*"):
-                await manager.send_execution_update(event)
-
-        async def notification_worker():
-            async for notification in notification_bus.listen("*"):
-                await manager.send_notification(
-                    user_id=notification.user_id,
-                    payload=notification.payload,
-                )
-
-        await asyncio.gather(execution_worker(), notification_worker())
-    finally:
-        # Ensure PubSub connections are closed on any exit to prevent leaks
-        await execution_bus.close()
-        await notification_bus.close()
-
-
 async def authenticate_websocket(websocket: WebSocket) -> str:
     if not settings.config.enable_auth:
         return DEFAULT_USER_ID
@@ -297,6 +283,21 @@ async def websocket_router(
                    ).model_dump_json()
                )
                continue
+            except ValueError as e:
+                logger.warning(
+                    "Subscription rejected for user #%s on '%s': %s",
+                    user_id,
+                    message.method.value,
+                    e,
+                )
+                await websocket.send_text(
+                    WSMessage(
+                        method=WSMethod.ERROR,
+                        success=False,
+                        error=str(e),
+                    ).model_dump_json()
+                )
+                continue
             except Exception as e:
                 logger.error(
                     f"Error while handling '{message.method.value}' message "
@@ -321,9 +322,13 @@ async def websocket_router(
                )

    except WebSocketDisconnect:
-        manager.disconnect_socket(websocket, user_id=user_id)
         logger.debug("WebSocket client disconnected")
+    except Exception:
+        logger.exception(f"Unexpected error in websocket_router for user #{user_id}")
    finally:
+        # Always release subscription pumps + Redis connections, regardless of how
+        # the loop exited — otherwise non-WebSocketDisconnect failures leak both.
+        await manager.disconnect_socket(websocket, user_id=user_id)
         update_websocket_connections(user_id, -1)

@@ -44,9 +44,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
         "backend.api.ws_api.build_cors_params", return_value=cors_params
     )

-    with override_config(
-        settings, "backend_cors_allow_origins", cors_params["allow_origins"]
-    ), override_config(settings, "app_env", AppEnvironment.LOCAL):
+    with (
+        override_config(
+            settings, "backend_cors_allow_origins", cors_params["allow_origins"]
+        ),
+        override_config(settings, "app_env", AppEnvironment.LOCAL),
+    ):
         WebsocketServer().run()

     build_cors.assert_called_once_with(
@@ -65,9 +68,12 @@
 def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
     mocker.patch("backend.api.ws_api.uvicorn.run")

-    with override_config(
-        settings, "backend_cors_allow_origins", ["http://localhost:3000"]
-    ), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
+    with (
+        override_config(
+            settings, "backend_cors_allow_origins", ["http://localhost:3000"]
+        ),
+        override_config(settings, "app_env", AppEnvironment.PRODUCTION),
+    ):
         with pytest.raises(ValueError):
             WebsocketServer().run()

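The `with (...)` form these two tests are rewritten to is the parenthesized context-manager syntax added in Python 3.10. A self-contained sketch with a hypothetical `override` helper shows the enter/exit ordering it preserves (last entered, first exited):

```python
from contextlib import contextmanager

events: list[str] = []


@contextmanager
def override(name, value):
    # Hypothetical stand-in for override_config(settings, name, value).
    events.append(f"set {name}")
    try:
        yield
    finally:
        events.append(f"restore {name}")


# Python 3.10+: multiple context managers may be parenthesized, so each
# manager gets its own line and a trailing comma, as in the tests above.
with (
    override("backend_cors_allow_origins", ["http://localhost:3000"]),
    override("app_env", "PRODUCTION"),
):
    events.append("body")

print(events)
# ['set backend_cors_allow_origins', 'set app_env', 'body',
#  'restore app_env', 'restore backend_cors_allow_origins']
```

Semantically this is identical to the old comma-chained form; the parentheses only buy line breaks without backslashes.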
@@ -290,7 +296,232 @@ async def test_handle_unsubscribe_missing_data(
         message=message,
     )

-    mock_manager._unsubscribe.assert_not_called()
+    mock_manager.unsubscribe_graph_exec.assert_not_called()
     mock_websocket.send_text.assert_called_once()
     assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
     assert '"success":false' in mock_websocket.send_text.call_args[0][0]
+
+
+# ---------- Per-graph subscribe branch ----------
+
+
+@pytest.mark.asyncio
+async def test_handle_subscribe_graph_execs_branch(
+    mock_websocket: AsyncMock, mock_manager: AsyncMock
+) -> None:
+    """The SUBSCRIBE_GRAPH_EXECS branch must route to subscribe_graph_execs,
+    not subscribe_graph_exec — regression guard for the aggregate channel."""
+    message = WSMessage(
+        method=WSMethod.SUBSCRIBE_GRAPH_EXECS,
+        data={"graph_id": "graph-abc"},
+    )
+    mock_manager.subscribe_graph_execs.return_value = (
+        "user-1|graph#graph-abc|executions"
+    )
+
+    await handle_subscribe(
+        connection_manager=cast(ConnectionManager, mock_manager),
+        websocket=cast(WebSocket, mock_websocket),
+        user_id="user-1",
+        message=message,
+    )
+
+    mock_manager.subscribe_graph_execs.assert_called_once_with(
+        user_id="user-1",
+        graph_id="graph-abc",
+        websocket=mock_websocket,
+    )
+    mock_manager.subscribe_graph_exec.assert_not_called()
+    mock_websocket.send_text.assert_called_once()
+    assert (
+        '"method":"subscribe_graph_executions"'
+        in mock_websocket.send_text.call_args[0][0]
+    )
+    assert '"success":true' in mock_websocket.send_text.call_args[0][0]
+
+
+@pytest.mark.asyncio
+async def test_handle_subscribe_rejects_unrelated_method(
+    mock_websocket: AsyncMock, mock_manager: AsyncMock
+) -> None:
+    """handle_subscribe must raise for methods that aren't SUBSCRIBE_*."""
+    import pytest as _pytest
+
+    message = WSMessage(
+        method=WSMethod.HEARTBEAT,
+        data={"graph_exec_id": "x"},
+    )
+
+    with _pytest.raises(ValueError):
+        await handle_subscribe(
+            connection_manager=cast(ConnectionManager, mock_manager),
+            websocket=cast(WebSocket, mock_websocket),
+            user_id="user-1",
+            message=message,
+        )
+
+
+# ---------- authenticate_websocket branches ----------
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_missing_token_closes_4001(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {}
+
+    user_id = await authenticate_websocket(ws)
+
+    ws.close.assert_awaited_once()
+    assert ws.close.call_args.kwargs["code"] == 4001
+    assert user_id == ""
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_invalid_token_closes_4003(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    mocker.patch(
+        "backend.api.ws_api.parse_jwt_token", side_effect=ValueError("bad token")
+    )
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {"token": "abc"}
+
+    user_id = await authenticate_websocket(ws)
+
+    ws.close.assert_awaited_once()
+    assert ws.close.call_args.kwargs["code"] == 4003
+    assert user_id == ""
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_missing_sub_closes_4002(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"not_sub": "x"})
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {"token": "abc"}
+
+    user_id = await authenticate_websocket(ws)
+
+    ws.close.assert_awaited_once()
+    assert ws.close.call_args.kwargs["code"] == 4002
+    assert user_id == ""
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_happy_path_returns_sub(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"sub": "user-X"})
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {"token": "abc"}
+
+    user_id = await authenticate_websocket(ws)
+
+    assert user_id == "user-X"
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_auth_disabled_returns_default(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", False)
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {}
+
+    user_id = await authenticate_websocket(ws)
+
+    assert user_id == DEFAULT_USER_ID
+
+
+# ---------- get_connection_manager singleton ----------
+
+
+def test_get_connection_manager_singleton() -> None:
+    """Repeated calls must return the same ConnectionManager — the WS router
+    depends on a single process-wide subscription table."""
+    import backend.api.ws_api as ws_api
+
+    ws_api._connection_manager = None
+    a = ws_api.get_connection_manager()
+    b = ws_api.get_connection_manager()
+    assert a is b
+    assert isinstance(a, ConnectionManager)
+
+
+# ---------- Lifespan: Prisma connect/disconnect ----------
+
+
+@pytest.mark.asyncio
+async def test_lifespan_connects_and_disconnects_prisma(mocker) -> None:
+    """Lifespan must both connect() and disconnect() db — the subscribe path
+    resolves graph_id via Prisma so a missing connect() is the regression bug."""
+    from fastapi import FastAPI
+
+    from backend.api.ws_api import lifespan
+
+    mock_db = mocker.patch("backend.api.ws_api.db")
+    mock_db.connect = AsyncMock()
+    mock_db.disconnect = AsyncMock()
+
+    dummy_app = FastAPI()
+    async with lifespan(dummy_app):
+        mock_db.connect.assert_awaited_once()
+        mock_db.disconnect.assert_not_called()
+    mock_db.disconnect.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_lifespan_still_disconnects_on_exception(mocker) -> None:
+    """If the app raises inside the yield, Prisma must still disconnect."""
+    from fastapi import FastAPI
+
+    from backend.api.ws_api import lifespan
+
+    mock_db = mocker.patch("backend.api.ws_api.db")
+    mock_db.connect = AsyncMock()
+    mock_db.disconnect = AsyncMock()
+
+    dummy_app = FastAPI()
+
+    class _Boom(Exception):
+        pass
+
+    with pytest.raises(_Boom):
+        async with lifespan(dummy_app):
+            raise _Boom()
+
+    mock_db.disconnect.assert_awaited_once()
+
+
+# ---------- Health endpoint ----------
+
+
+def test_health_endpoint_returns_ok() -> None:
+    # TestClient triggers lifespan — stub it out so Prisma isn't hit.
+    from contextlib import asynccontextmanager
+
+    from fastapi.testclient import TestClient
+
+    import backend.api.ws_api as ws_api
+
+    @asynccontextmanager
+    async def _noop_lifespan(app):
+        yield
+
+    # Replace the app-level lifespan temporarily.
+    real_router_lifespan = ws_api.app.router.lifespan_context
+    ws_api.app.router.lifespan_context = _noop_lifespan
+    try:
+        with TestClient(ws_api.app) as client:
+            r = client.get("/")
+            assert r.status_code == 200
+            assert r.json() == {"status": "healthy"}
+    finally:
+        ws_api.app.router.lifespan_context = real_router_lifespan
@@ -38,19 +38,23 @@ def main(**kwargs):
|
|||||||
|
|
||||||
from backend.api.rest_api import AgentServer
|
from backend.api.rest_api import AgentServer
|
||||||
from backend.api.ws_api import WebsocketServer
|
from backend.api.ws_api import WebsocketServer
|
||||||
|
from backend.copilot.bot.app import CoPilotChatBridge
|
||||||
from backend.copilot.executor.manager import CoPilotExecutor
|
from backend.copilot.executor.manager import CoPilotExecutor
|
||||||
from backend.data.db_manager import DatabaseManager
|
from backend.data.db_manager import DatabaseManager
|
||||||
from backend.executor import ExecutionManager, Scheduler
|
from backend.executor import ExecutionManager, Scheduler
|
||||||
from backend.notifications import NotificationManager
|
from backend.notifications import NotificationManager
|
||||||
|
from backend.platform_linking.manager import PlatformLinkingManager
|
||||||
|
|
||||||
run_processes(
|
run_processes(
|
||||||
DatabaseManager().set_log_level("warning"),
|
DatabaseManager().set_log_level("warning"),
|
||||||
Scheduler(),
|
Scheduler(),
|
||||||
NotificationManager(),
|
NotificationManager(),
|
||||||
|
PlatformLinkingManager(),
|
||||||
WebsocketServer(),
|
WebsocketServer(),
|
||||||
AgentServer(),
|
AgentServer(),
|
||||||
ExecutionManager(),
|
ExecutionManager(),
|
||||||
CoPilotExecutor(),
|
CoPilotExecutor(),
|
||||||
|
CoPilotChatBridge(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -96,27 +96,64 @@ class BlockCategory(Enum):
 
 
 class BlockCostType(str, Enum):
-    RUN = "run"  # cost X credits per run
-    BYTE = "byte"  # cost X credits per byte
-    SECOND = "second"  # cost X credits per second
+    # RUN : cost_amount credits per run.
+    # BYTE : cost_amount credits per byte of input data.
+    # SECOND : cost_amount credits per cost_divisor walltime seconds.
+    # ITEMS : cost_amount credits per cost_divisor items (from stats).
+    # COST_USD : cost_amount credits per USD of stats.provider_cost.
+    # TOKENS : per-(model, provider) rate table lookup; see TOKEN_COST.
+    RUN = "run"
+    BYTE = "byte"
+    SECOND = "second"
+    ITEMS = "items"
+    COST_USD = "cost_usd"
+    TOKENS = "tokens"
+
+    @property
+    def is_dynamic(self) -> bool:
+        """Real charge is computed post-flight from stats.
+
+        Dynamic types (SECOND/ITEMS/COST_USD/TOKENS) return 0 pre-flight and
+        settle against stats via charge_reconciled_usage once the block runs.
+        """
+        return self in _DYNAMIC_COST_TYPES
+
+
+_DYNAMIC_COST_TYPES: frozenset[BlockCostType] = frozenset(
+    {
+        BlockCostType.SECOND,
+        BlockCostType.ITEMS,
+        BlockCostType.COST_USD,
+        BlockCostType.TOKENS,
+    }
+)
 
 
 class BlockCost(BaseModel):
     cost_amount: int
     cost_filter: BlockInput
     cost_type: BlockCostType
+    # cost_divisor: interpret cost_amount as "credits per cost_divisor units".
+    # Only meaningful for SECOND / ITEMS. TOKENS routes through TOKEN_COST
+    # rate tables (per-model input/output/cache pricing) and ignores
+    # cost_divisor entirely. Defaults to 1 so existing RUN/BYTE entries stay
+    # point-wise. Example: cost_amount=1, cost_divisor=10 under SECOND means
+    # "1 credit per 10 seconds of walltime".
+    cost_divisor: int = 1
 
     def __init__(
         self,
         cost_amount: int,
         cost_type: BlockCostType = BlockCostType.RUN,
         cost_filter: Optional[BlockInput] = None,
+        cost_divisor: int = 1,
         **data: Any,
     ) -> None:
         super().__init__(
             cost_amount=cost_amount,
             cost_filter=cost_filter or {},
             cost_type=cost_type,
+            cost_divisor=max(1, cost_divisor),
             **data,
         )
 
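The `cost_divisor` semantics above read as "credits per `cost_divisor` units". A standalone sketch of that arithmetic (hypothetical helper, not the platform's actual billing code — the real settlement happens in `charge_reconciled_usage`, and the round-up here is an assumption):

```python
from enum import Enum
from math import ceil


class CostType(str, Enum):
    RUN = "run"
    SECOND = "second"
    ITEMS = "items"


def charge(cost_amount: int, cost_type: CostType, units: float, cost_divisor: int = 1) -> int:
    """Credits owed for one run: cost_amount per cost_divisor units."""
    if cost_type is CostType.RUN:
        return cost_amount  # flat, pre-flight charge
    # Dynamic types settle post-flight from observed usage; partial
    # windows are rounded up here (an assumption, not stated in the diff).
    return ceil(units / cost_divisor) * cost_amount


# cost_amount=1, cost_divisor=10 under SECOND: "1 credit per 10 seconds".
print(charge(1, CostType.SECOND, units=25, cost_divisor=10))  # 3
```

With `cost_divisor=1` (the default) dynamic charges stay point-wise, which is why existing RUN/BYTE entries are unaffected.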
@@ -168,9 +205,31 @@ class BlockSchema(BaseModel):
         return cls.cached_jsonschema
 
     @classmethod
-    def validate_data(cls, data: BlockInput) -> str | None:
+    def validate_data(
+        cls,
+        data: BlockInput,
+        exclude_fields: set[str] | None = None,
+    ) -> str | None:
+        schema = cls.jsonschema()
+        if exclude_fields:
+            # Drop the excluded fields from both the properties and the
+            # ``required`` list so jsonschema doesn't flag them as missing.
+            # Used by the dry-run path to skip credentials validation while
+            # still validating the remaining block inputs.
+            schema = {
+                **schema,
+                "properties": {
+                    k: v
+                    for k, v in schema.get("properties", {}).items()
+                    if k not in exclude_fields
+                },
+                "required": [
+                    r for r in schema.get("required", []) if r not in exclude_fields
+                ],
+            }
+            data = {k: v for k, v in data.items() if k not in exclude_fields}
         return json.validate_with_jsonschema(
-            schema=cls.jsonschema(),
+            schema=schema,
            data={k: v for k, v in data.items() if v is not None},
         )
 
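The `exclude_fields` trimming can be sketched in isolation: a field must leave both `properties` and `required`, otherwise the validator still reports it as a missing required property. This is a simplified, dependency-free stand-in (plain dicts, no jsonschema call):

```python
def schema_without_fields(schema: dict, exclude: set[str]) -> dict:
    """Return a copy of a JSON-Schema object with the given top-level
    fields removed from both 'properties' and 'required', so absent
    fields are no longer reported as missing."""
    return {
        **schema,
        "properties": {
            k: v for k, v in schema.get("properties", {}).items() if k not in exclude
        },
        "required": [r for r in schema.get("required", []) if r not in exclude],
    }


schema = {
    "type": "object",
    "properties": {"credentials": {"type": "object"}, "url": {"type": "string"}},
    "required": ["credentials", "url"],
}
trimmed = schema_without_fields(schema, {"credentials"})
print(trimmed["required"])  # ['url']
```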
@@ -311,6 +370,8 @@ class BlockSchema(BaseModel):
             "credentials_provider": [config.get("provider", "google")],
             "credentials_types": [config.get("type", "oauth2")],
             "credentials_scopes": config.get("scopes"),
+            "is_auto_credential": True,
+            "input_field_name": info["field_name"],
         }
         result[kwarg_name] = CredentialsFieldInfo.model_validate(
             auto_schema, by_alias=True
@@ -421,19 +482,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
 class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
     _optimized_description: ClassVar[str | None] = None
-
-    def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
-        """Return extra runtime cost to charge after this block run completes.
-
-        Called by the executor after a block finishes with COMPLETED status.
-        The return value is the number of additional base-cost credits to
-        charge beyond the single credit already collected by charge_usage
-        at the start of execution. Defaults to 0 (no extra charges).
-
-        Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
-        calls within one run and should be billed per call.
-        """
-        return 0
 
     def __init__(
         self,
         id: str = "",
@@ -717,11 +765,16 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         # (e.g. AgentExecutorBlock) get proper input validation.
         is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
         if is_dry_run:
+            # Credential fields may be absent (LLM-built agents often skip
+            # wiring them) or nullified earlier in the pipeline. Validate
+            # the non-credential inputs against a schema with those fields
+            # excluded — stripping only the data while keeping them in the
+            # ``required`` list would falsely report ``'credentials' is a
+            # required property``.
             cred_field_names = set(self.input_schema.get_credentials_fields().keys())
-            non_cred_data = {
-                k: v for k, v in input_data.items() if k not in cred_field_names
-            }
-            if error := self.input_schema.validate_data(non_cred_data):
+            if error := self.input_schema.validate_data(
+                input_data, exclude_fields=cred_field_names
+            ):
                 raise BlockInputError(
                     message=f"Unable to execute block with invalid input data: {error}",
                     block_name=self.name,
@@ -735,6 +788,61 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
                     block_id=self.id,
                 )
+
+        # Ensure auto-credential kwargs are present before we hand off to
+        # run(). A missing auto-credential means the upstream field (e.g.
+        # a Google Drive picker) didn't embed a _credentials_id, or the
+        # executor couldn't resolve it. Without this guard, run() would
+        # crash with a TypeError (missing required kwarg) or an opaque
+        # AttributeError deep inside the provider SDK.
+        #
+        # Only raise when the field is ALSO not populated in input_data.
+        # ``_acquire_auto_credentials`` intentionally skips setting the
+        # kwarg in two legitimate cases — ``_credentials_id`` is ``None``
+        # (chained from upstream) or the field is missing from
+        # ``input_data`` at prep time (connected from upstream block).
+        # In both cases the upstream block is expected to populate the
+        # field value by execute time; raising here would break the
+        # documented ``AgentGoogleDriveFileInputBlock`` chaining pattern.
+        # Dry-run skips because the executor intentionally runs blocks
+        # without resolved creds for schema validation.
+        if not is_dry_run:
+            for (
+                kwarg_name,
+                info,
+            ) in self.input_schema.get_auto_credentials_fields().items():
+                kwargs.setdefault(kwarg_name, None)
+                if kwargs[kwarg_name] is not None:
+                    continue
+                # Upstream-chained pattern: the field was populated by a
+                # prior node (e.g. AgentGoogleDriveFileInputBlock) whose
+                # output carries a resolved ``_credentials_id``.
+                # ``_acquire_auto_credentials`` deliberately doesn't set
+                # the kwarg in that case because the value isn't available
+                # at prep time; the executor fills it in before we reach
+                # ``_execute``. Trust it if the ``_credentials_id`` KEY
+                # is present — its value may be explicitly ``None`` in
+                # the chained case (see sentry thread
+                # PRRT_kwDOJKSTjM58sJfA). Checking truthiness here would
+                # falsely preempt run() for every valid chained graph
+                # that ships ``_credentials_id=None`` in the picker
+                # object. Mirror ``_acquire_auto_credentials``'s own
+                # skip rule, which treats ``cred_id is None`` as a
+                # chained-skip signal.
+                field_name = info["field_name"]
+                field_value = input_data.get(field_name)
+                if isinstance(field_value, dict) and "_credentials_id" in field_value:
+                    continue
+                raise BlockExecutionError(
+                    message=(
+                        f"Missing credentials for '{kwarg_name}'. "
+                        "Select a file via the picker (which carries "
+                        "its credentials), or connect credentials for "
+                        "this block."
+                    ),
+                    block_name=self.name,
+                    block_id=self.id,
+                )
 
         # Use the validated input data
         async for output_name, output_data in self.run(
             self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
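The guard's decision table (kwarg already resolved, picker object carrying the `_credentials_id` KEY even when its value is `None`, otherwise raise) is small enough to sketch on its own. Names here are illustrative, not the real block framework API:

```python
def should_raise_for_missing_auto_credential(kwarg_value, field_value) -> bool:
    """Illustrative mirror of the auto-credential guard: a resolved kwarg,
    or a picker dict that carries the '_credentials_id' KEY (even with
    value None, the chained case), is trusted; anything else raises."""
    if kwarg_value is not None:
        return False  # executor already resolved the credential
    if isinstance(field_value, dict) and "_credentials_id" in field_value:
        return False  # chained from an upstream picker block
    return True


print(should_raise_for_missing_auto_credential(None, {"_credentials_id": None}))  # False
print(should_raise_for_missing_auto_credential(None, None))  # True
```

Note the key-presence check rather than a truthiness check: a truthiness check would wrongly raise for every chained graph that ships `_credentials_id=None` in the picker object.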
@@ -0,0 +1,56 @@
+"""Provider descriptions for services that don't yet have their own ``_config.py``.
+
+Every provider in ``_STATIC_PROVIDER_CONFIGS`` below is declared here because
+its block code currently lives either in a single shared file (e.g. the 8 LLM
+providers in ``blocks/llm.py``) or in a single-file block that has no dedicated
+directory (e.g. ``blocks/reddit.py``).
+
+This file gets loaded by the block auto-loader in ``blocks/__init__.py``
+(``rglob("*.py")`` picks it up) so the ``ProviderBuilder(...).build()`` calls
+run at startup and populate ``AutoRegistry`` before the first API request.
+
+**Migration path:** when a provider graduates into its own directory with a
+proper ``_config.py`` (following the SDK pattern, e.g. ``blocks/linear/_config.py``),
+delete its entry here. The metadata will still be served by
+``GET /integrations/providers`` — it just moves to live next to the provider's
+auth and webhook config.
+"""
+
+from backend.data.model import CredentialsType
+from backend.sdk import ProviderBuilder
+
+_STATIC_PROVIDER_CONFIGS: dict[str, tuple[str, tuple[CredentialsType, ...]]] = {
+    # LLM providers that share blocks/llm.py
+    "aiml_api": ("Unified access to 100+ AI models", ("api_key",)),
+    "anthropic": ("Claude language models", ("api_key",)),
+    "groq": ("Fast LLM inference", ("api_key",)),
+    "llama_api": ("Llama model hosting", ("api_key",)),
+    "ollama": ("Run open-source LLMs locally", ("api_key",)),
+    "open_router": ("One API for every LLM", ("api_key",)),
+    "openai": ("GPT models and embeddings", ("api_key",)),
+    "v0": ("AI-generated UI components", ("api_key",)),
+    # Single-file providers (one provider per standalone blocks/*.py file)
+    "d_id": ("AI avatar and video generation", ("api_key",)),
+    "e2b": ("Sandboxed code execution", ("api_key",)),
+    "google_maps": ("Places, directions, geocoding", ("api_key",)),
+    "http": ("Generic HTTP requests", ("api_key", "host_scoped")),
+    "ideogram": ("Text-to-image generation", ("api_key",)),
+    "medium": ("Publish stories and posts", ("api_key",)),
+    "mem0": ("Long-term memory for agents", ("api_key",)),
+    "openweathermap": ("Weather data and forecasts", ("api_key",)),
+    "pinecone": ("Managed vector database", ("api_key",)),
+    "reddit": ("Subreddits, posts, and comments", ("oauth2",)),
+    "revid": ("AI-generated short-form video", ("api_key",)),
+    "screenshotone": ("Automated website screenshots", ("api_key",)),
+    "smtp": ("Send email via SMTP", ("user_password",)),
+    "unreal_speech": ("Low-cost text-to-speech", ("api_key",)),
+    "webshare_proxy": ("Rotating proxies for scraping", ("api_key",)),
+}
+
+for _name, (_description, _auth_types) in _STATIC_PROVIDER_CONFIGS.items():
+    (
+        ProviderBuilder(_name)
+        .with_description(_description)
+        .with_supported_auth_types(*_auth_types)
+        .build()
+    )
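The dict-driven registration pattern used by `_STATIC_PROVIDER_CONFIGS` is easy to demonstrate with a toy stand-in. Everything below (`Builder`, `REGISTRY`) is invented for illustration and is not the real `backend.sdk.ProviderBuilder` or `AutoRegistry` API; the point is only that iterating a module-level dict at import time populates a registry before any request arrives:

```python
REGISTRY: dict[str, dict] = {}


class Builder:
    """Toy fluent builder; each with_* call returns self for chaining."""

    def __init__(self, name: str) -> None:
        self._meta = {"name": name}

    def with_description(self, description: str) -> "Builder":
        self._meta["description"] = description
        return self

    def with_supported_auth_types(self, *auth_types: str) -> "Builder":
        self._meta["auth_types"] = auth_types
        return self

    def build(self) -> dict:
        REGISTRY[self._meta["name"]] = self._meta
        return self._meta


_CONFIGS = {"reddit": ("Subreddits, posts, and comments", ("oauth2",))}
for name, (description, auth_types) in _CONFIGS.items():
    Builder(name).with_description(description).with_supported_auth_types(*auth_types).build()

print(sorted(REGISTRY))  # ['reddit']
```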
@@ -171,7 +171,10 @@ class AgentExecutorBlock(Block):
             )
             self.merge_stats(
                 NodeExecutionStats(
-                    extra_cost=event.stats.cost if event.stats else 0,
+                    # Sub-graph already debited each of its own nodes; we
+                    # roll up its total so graph_stats.cost reflects the
+                    # full sub-graph spend.
+                    reconciled_cost_delta=(event.stats.cost if event.stats else 0),
                     extra_steps=event.stats.node_exec_count if event.stats else 0,
                 )
             )
@@ -4,11 +4,17 @@ Shared configuration for all AgentMail blocks.
 
 from agentmail import AsyncAgentMail
 
-from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
+from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, SecretStr
 
+# AgentMail is in beta with no published paid tier yet, but ~37 blocks
+# without any BLOCK_COSTS entry means they currently execute wallet-free.
+# 1 cr/call is a conservative interim floor so no AgentMail work leaks
+# past billing. Revisit once AgentMail publishes usage-based pricing.
 agent_mail = (
     ProviderBuilder("agent_mail")
+    .with_description("Managed email accounts for agents")
     .with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
+    .with_base_cost(1, BlockCostType.RUN)
     .build()
 )
 
@@ -10,6 +10,7 @@ from ._webhook import AirtableWebhookManager
 # Configure the Airtable provider with API key authentication
 airtable = (
     ProviderBuilder("airtable")
+    .with_description("Bases, tables, and records")
     .with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
     .with_webhook_manager(AirtableWebhookManager)
     .with_base_cost(1, BlockCostType.RUN)
15  autogpt_platform/backend/backend/blocks/apollo/_config.py  (new file)
@@ -0,0 +1,15 @@
+"""Provider registration for Apollo.
+
+Registers the provider description shown in the settings integrations UI.
+Apollo doesn't use a full :class:`ProviderBuilder` chain (auth is set up in
+``_auth.py``), so this file only declares metadata.
+"""
+
+from backend.sdk import ProviderBuilder
+
+apollo = (
+    ProviderBuilder("apollo")
+    .with_description("Sales intelligence and prospecting")
+    .with_supported_auth_types("api_key")
+    .build()
+)
@@ -7,6 +7,7 @@ import logging
 import uuid
 from typing import TYPE_CHECKING, Any
 
+from pydantic import field_validator
 from typing_extensions import TypedDict  # Needed for Python <3.12 compatibility
 
 from backend.blocks._base import (
@@ -17,12 +18,14 @@ from backend.blocks._base import (
     BlockSchemaOutput,
 )
 from backend.copilot.permissions import (
+    DISABLED_LEGACY_TOOL_NAMES,
     CopilotPermissions,
     ToolName,
     all_known_tool_names,
     validate_block_identifiers,
 )
 from backend.data.model import SchemaField
+from backend.util.exceptions import BlockExecutionError
 
 if TYPE_CHECKING:
     from backend.data.execution import ExecutionContext
@@ -32,9 +35,36 @@ logger = logging.getLogger(__name__)
 # Block ID shared between autopilot.py and copilot prompting.py.
 AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
 
+# Identifiers used when registering an AutoPilotBlock turn with the
+# stream registry — distinguishes block-originated turns from sub-session
+# or HTTP SSE turns in logs / observability.
+_AUTOPILOT_TOOL_CALL_ID = "autopilot_block"
+_AUTOPILOT_TOOL_NAME = "autopilot_block"
 
-class SubAgentRecursionError(RuntimeError):
-    """Raised when the sub-agent nesting depth limit is exceeded."""
+# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the
+# enqueued turn's terminal event. Graph blocks run synchronously from
+# the caller's perspective so we wait effectively as long as needed; 6h
+# matches the previous abandoned-task cap and is much longer than any
+# legitimate AutoPilot turn.
+_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60  # 6 hours
+
+
+class SubAgentRecursionError(BlockExecutionError):
+    """Raised when the AutoPilot sub-agent nesting depth limit is exceeded.
+
+    Inherits :class:`BlockExecutionError` — this is a known, handled
+    runtime failure at the block level (caller nested AutoPilotBlocks
+    beyond the configured limit). Surfaces with the block_name /
+    block_id the block framework expects, instead of being wrapped in
+    ``BlockUnknownError``.
+    """
+
+    def __init__(self, message: str) -> None:
+        super().__init__(
+            message=message,
+            block_name="AutoPilotBlock",
+            block_id=AUTOPILOT_BLOCK_ID,
+        )
 
 
 class ToolCallEntry(TypedDict):
@@ -170,6 +200,13 @@ class AutoPilotBlock(Block):
     # timeouts internally; wrapping with asyncio.timeout corrupts the
     # SDK's internal stream (see service.py CRITICAL comment).
 
+    @field_validator("tools", mode="before")
+    @classmethod
+    def strip_disabled_legacy_tools(cls, tools: Any) -> Any:
+        if not isinstance(tools, list):
+            return tools
+        return [tool for tool in tools if tool not in DISABLED_LEGACY_TOOL_NAMES]
+
     class Output(BlockSchemaOutput):
         """Output schema for the AutoPilot block."""
 
@@ -268,11 +305,15 @@ class AutoPilotBlock(Block):
         user_id: str,
         permissions: "CopilotPermissions | None" = None,
     ) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
-        """Invoke the copilot and collect all stream results.
+        """Invoke the copilot on the copilot_executor queue and aggregate the
+        result.
 
-        Delegates to :func:`collect_copilot_response` — the shared helper that
-        consumes ``stream_chat_completion_sdk`` without wrapping it in an
-        ``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
+        Delegates to :func:`run_copilot_turn_via_queue` — the shared
+        primitive used by ``run_sub_session`` too — which creates the
+        stream_registry meta record, enqueues the job, and waits on the
+        Redis stream for the terminal event. Any available
+        copilot_executor worker picks up the job, so this call survives
+        the graph-executor worker dying mid-turn (RabbitMQ redelivers).
 
         Args:
             prompt: The user task/instruction.
@@ -285,8 +326,8 @@ class AutoPilotBlock(Block):
         Returns:
             A tuple of (response_text, tool_calls, history_json, session_id, usage).
         """
-        from backend.copilot.sdk.collect import (
-            collect_copilot_response,  # avoid circular import
+        from backend.copilot.sdk.session_waiter import (
+            run_copilot_turn_via_queue,  # avoid circular import
         )
 
         tokens = _check_recursion(max_recursion_depth)
@@ -299,14 +340,35 @@ class AutoPilotBlock(Block):
         if system_context:
             effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
 
-        result = await collect_copilot_response(
+        outcome, result = await run_copilot_turn_via_queue(
             session_id=session_id,
-            message=effective_prompt,
             user_id=user_id,
+            message=effective_prompt,
+            # Graph block execution is synchronous from the caller's
+            # perspective — wait effectively as long as needed. The
+            # SDK enforces its own idle-based timeout inside the
+            # stream_registry pipeline.
+            timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS,
             permissions=effective_permissions,
+            tool_call_id=_AUTOPILOT_TOOL_CALL_ID,
+            tool_name=_AUTOPILOT_TOOL_NAME,
         )
+        if outcome == "failed":
+            raise RuntimeError(
+                "AutoPilot turn failed — see the session's transcript"
+            )
+        if outcome == "running":
+            raise RuntimeError(
+                "AutoPilot turn did not complete within "
+                f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session "
+                f"{session_id}"
+            )
 
-        # Build a lightweight conversation summary from streamed data.
+        # Build a lightweight conversation summary from the aggregated data.
+        # When ``result.queued`` is True the prompt rode on an already-
+        # in-flight turn (``run_copilot_turn_via_queue`` queued it and
+        # waited on the existing turn's stream); the aggregated result
+        # is still valid, so the same rendering path applies.
         turn_messages: list[dict[str, Any]] = [
             {"role": "user", "content": effective_prompt},
         ]
@@ -315,7 +377,7 @@ class AutoPilotBlock(Block):
             {
                 "role": "assistant",
                 "content": result.response_text,
-                "tool_calls": result.tool_calls,
+                "tool_calls": [tc.model_dump() for tc in result.tool_calls],
             }
         )
     else:
@@ -326,11 +388,11 @@ class AutoPilotBlock(Block):
 
         tool_calls: list[ToolCallEntry] = [
             {
-                "tool_call_id": tc["tool_call_id"],
-                "tool_name": tc["tool_name"],
-                "input": tc["input"],
-                "output": tc["output"],
-                "success": tc["success"],
+                "tool_call_id": tc.tool_call_id,
+                "tool_name": tc.tool_name,
+                "input": tc.input,
+                "output": tc.output,
+                "success": tc.success,
             }
             for tc in result.tool_calls
         ]
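The `mode="before"` validator pattern used for `tools` can be shown as a plain function: the pre-validation hook silently drops retired tool names so old saved graphs still validate, and passes non-list input through untouched for normal validation to reject. The set contents here are a hypothetical subset, not the real `DISABLED_LEGACY_TOOL_NAMES`:

```python
DISABLED_LEGACY_TOOL_NAMES = {"ask_question"}  # hypothetical contents


def strip_disabled_legacy_tools(tools):
    """Behaviour of the 'before'-mode hook: drop retired tool names,
    leave non-list input untouched so the type check still fires."""
    if not isinstance(tools, list):
        return tools
    return [t for t in tools if t not in DISABLED_LEGACY_TOOL_NAMES]


print(strip_disabled_legacy_tools(["ask_question", "run_block"]))  # ['run_block']
```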
@@ -62,6 +62,14 @@ class TestBuildAndValidatePermissions:
         with pytest.raises(ValidationError, match="not_a_real_tool"):
             _make_input(tools=["not_a_real_tool"])
 
+    async def test_disabled_legacy_tool_is_accepted_and_removed(self):
+        inp = _make_input(tools=["ask_question", "run_block"])
+        result = await _build_and_validate_permissions(inp)
+
+        assert inp.tools == ["run_block"]
+        assert isinstance(result, CopilotPermissions)
+        assert result.tools == ["run_block"]
+
     async def test_valid_block_name_accepted(self):
         mock_block_cls = MagicMock()
         mock_block_cls.return_value.name = "HTTP Request"
26  autogpt_platform/backend/backend/blocks/ayrshare/_config.py  (new file)
@@ -0,0 +1,26 @@
+"""Shared provider config for Ayrshare social-media blocks.
+
+The "credential" exposed to blocks is the **per-user Ayrshare profile key**,
+not the org-level ``AYRSHARE_API_KEY``. Profile keys are provisioned per
+user by :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`
+and stored in the normal credentials list with ``is_managed=True``, so every
+Ayrshare block fits the standard credential flow:
+
+    credentials: CredentialsMetaInput = ayrshare.credentials_field(...)
+
+``run_block`` / ``resolve_block_credentials`` take care of the rest.
+
+``with_managed_api_key()`` registers ``api_key`` as a supported auth type
+without the env-var-backed default credential that ``with_api_key()`` would
+create — the org-level ``AYRSHARE_API_KEY`` is the admin key and must never
+reach a block as a "profile key".
+"""
+
+from backend.sdk import ProviderBuilder
+
+ayrshare = (
+    ProviderBuilder("ayrshare")
+    .with_description("Post to every social network")
+    .with_managed_api_key()
+    .build()
+)
18  autogpt_platform/backend/backend/blocks/ayrshare/_cost.py  (new file)
@@ -0,0 +1,18 @@
+from backend.sdk import BlockCost, BlockCostType
+
+# Ayrshare is a subscription proxy ($149/mo Business). Per-post credit charges
+# prevent a single heavy user from absorbing the fixed cost and align with the
+# upload cost of each post variant.
+# cost_filter matches on input_data.is_video BEFORE run() executes, so the flag
+# has to be correct at input-eval time. Video-only platforms (YouTube, Snapchat)
+# override the base default to True; platforms that accept both (TikTok, etc.)
+# rely on the caller setting is_video explicitly for accurate billing.
+# First match wins in block_usage_cost, so list the video tier first.
+AYRSHARE_POST_COSTS = (
+    BlockCost(
+        cost_amount=5, cost_type=BlockCostType.RUN, cost_filter={"is_video": True}
+    ),
+    BlockCost(
+        cost_amount=2, cost_type=BlockCostType.RUN, cost_filter={"is_video": False}
+    ),
+)
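The "first match wins" selection over `cost_filter` entries can be sketched with plain dicts. This is a simplified stand-in for `block_usage_cost`'s matching, not the real implementation, and it assumes a filter matches when it is a sub-dict of the input:

```python
def pick_cost(costs: tuple, input_data: dict) -> int:
    """Return the amount of the first cost whose filter is a sub-dict of
    the input — which is why the video tier must be listed first."""
    for cost in costs:
        if all(input_data.get(k) == v for k, v in cost["filter"].items()):
            return cost["amount"]
    return 0


POST_COSTS = (
    {"amount": 5, "filter": {"is_video": True}},
    {"amount": 2, "filter": {"is_video": False}},
)
print(pick_cost(POST_COSTS, {"is_video": True, "post": "hi"}))  # 5
```

Because the filter is evaluated against `input_data` before `run()` executes, `is_video` has to be set correctly by the caller for the right tier to bill.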
@@ -4,22 +4,25 @@ from typing import Optional
 from pydantic import BaseModel, Field
 
 from backend.blocks._base import BlockSchemaInput
-from backend.data.model import SchemaField, UserIntegrations
+from backend.data.model import CredentialsMetaInput, SchemaField
 from backend.integrations.ayrshare import AyrshareClient
-from backend.util.clients import get_database_manager_async_client
 from backend.util.exceptions import MissingConfigError
 
-async def get_profile_key(user_id: str):
-    user_integrations: UserIntegrations = (
-        await get_database_manager_async_client().get_user_integrations(user_id)
-    )
-    return user_integrations.managed_credentials.ayrshare_profile_key
+from ._config import ayrshare
 
 
 class BaseAyrshareInput(BlockSchemaInput):
     """Base input model for Ayrshare social media posts with common fields."""
 
+    credentials: CredentialsMetaInput = ayrshare.credentials_field(
+        description=(
+            "Ayrshare profile credential. AutoGPT provisions this managed "
+            "credential automatically — the user does not create it. After "
+            "it's in place, the user links each social account via the "
+            "Ayrshare SSO popup in the Builder."
+        ),
+    )
 
     post: str = SchemaField(
         description="The post text to be published", default="", advanced=False
     )
@@ -29,7 +32,9 @@ class BaseAyrshareInput(BlockSchemaInput):
         advanced=False,
     )
     is_video: bool = SchemaField(
-        description="Whether the media is a video", default=False, advanced=True
+        description="Whether the media is a video. Set to True when uploading a video so billing applies the video tier.",
+        default=False,
+        advanced=True,
     )
     schedule_date: Optional[datetime] = SchemaField(
         description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
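The hunks above replace the per-user database lookup (`get_profile_key`) with a platform-injected `credentials` object whose key is only revealed via `get_secret_value()`. A rough sketch of that pattern, with a hand-rolled `SecretStr` standing in for pydantic's (names and shapes are illustrative assumptions, not the actual AutoGPT classes):

```python
class SecretStr:
    """Minimal stand-in for pydantic's SecretStr: masked unless explicitly unwrapped."""

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        return "SecretStr('**********')"

    def get_secret_value(self) -> str:
        return self._value


class APIKeyCredentials:
    """Hypothetical shape of the credential object injected into run()."""

    def __init__(self, api_key: SecretStr):
        self.api_key = api_key


creds = APIKeyCredentials(api_key=SecretStr("profile-key-123"))
print(repr(creds.api_key))               # masked if it ever leaks into logs
print(creds.api_key.get_secret_value())  # the real key, only on explicit request
```

The payoff is that the blocks no longer need a "please link your account" early return: if the managed credential is missing, input validation fails before `run()` is ever called.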
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToBlueskyBlock(Block):
     """Block for posting to Bluesky with Bluesky-specific options."""
 
@@ -57,16 +61,10 @@
         self,
         input_data: "PostToBlueskyBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Bluesky with Bluesky-specific options."""
-
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -106,7 +104,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             bluesky_options=bluesky_options if bluesky_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,21 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import (
-    BaseAyrshareInput,
-    CarouselItem,
-    create_ayrshare_client,
-    get_profile_key,
-)
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToFacebookBlock(Block):
     """Block for posting to Facebook with Facebook-specific options."""
 
@@ -120,15 +119,10 @@
         self,
         input_data: "PostToFacebookBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Facebook with Facebook-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -204,7 +198,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             facebook_options=facebook_options if facebook_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToGMBBlock(Block):
     """Block for posting to Google My Business with GMB-specific options."""
 
@@ -110,14 +114,13 @@
     )
 
     async def run(
-        self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
+        self,
+        input_data: "PostToGMBBlock.Input",
+        *,
+        credentials: APIKeyCredentials,
+        **kwargs
     ) -> BlockOutput:
         """Post to Google My Business with GMB-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -202,7 +205,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             gmb_options=gmb_options if gmb_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -2,22 +2,21 @@ from typing import Any
 
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import (
-    BaseAyrshareInput,
-    InstagramUserTag,
-    create_ayrshare_client,
-    get_profile_key,
-)
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToInstagramBlock(Block):
     """Block for posting to Instagram with Instagram-specific options."""
 
@@ -112,15 +111,10 @@
         self,
         input_data: "PostToInstagramBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Instagram with Instagram-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -241,7 +235,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             instagram_options=instagram_options if instagram_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToLinkedInBlock(Block):
     """Block for posting to LinkedIn with LinkedIn-specific options."""
 
@@ -112,15 +116,10 @@
         self,
         input_data: "PostToLinkedInBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to LinkedIn with LinkedIn-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -214,7 +213,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             linkedin_options=linkedin_options if linkedin_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
        )
         yield "post_result", response
         if response.postIds:
@@ -1,21 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import (
-    BaseAyrshareInput,
-    PinterestCarouselOption,
-    create_ayrshare_client,
-    get_profile_key,
-)
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToPinterestBlock(Block):
     """Block for posting to Pinterest with Pinterest-specific options."""
 
@@ -92,15 +91,10 @@
         self,
         input_data: "PostToPinterestBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Pinterest with Pinterest-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -206,7 +200,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             pinterest_options=pinterest_options if pinterest_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToRedditBlock(Block):
     """Block for posting to Reddit."""
 
@@ -35,12 +39,12 @@
     )
 
     async def run(
-        self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
+        self,
+        input_data: "PostToRedditBlock.Input",
+        *,
+        credentials: APIKeyCredentials,
+        **kwargs
     ) -> BlockOutput:
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured."
@@ -61,7 +65,7 @@
             random_post=input_data.random_post,
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToSnapchatBlock(Block):
     """Block for posting to Snapchat with Snapchat-specific options."""
 
@@ -31,6 +35,14 @@
         advanced=False,
     )
 
+    # Snapchat is video-only; override the base default so the @cost filter
+    # selects the 5-credit video tier instead of the 2-credit image tier.
+    is_video: bool = SchemaField(
+        description="Whether the media is a video (always True for Snapchat)",
+        default=True,
+        advanced=True,
+    )
+
     # Snapchat-specific options
     story_type: str = SchemaField(
         description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
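Snapchat's `is_video` override works because redeclaring a field in a subclass replaces only its default. A simplified illustration with plain dataclasses (the real blocks declare `SchemaField`s on a `BlockSchemaInput` base, not dataclasses):

```python
from dataclasses import dataclass


@dataclass
class BaseAyrshareInput:
    # Most platforms accept both images and video, so default to image billing.
    post: str = ""
    is_video: bool = False


@dataclass
class SnapchatInput(BaseAyrshareInput):
    # Redeclaring the field changes only its default; a cost filter reading
    # is_video now lands on the video tier unless the caller says otherwise.
    is_video: bool = True


print(BaseAyrshareInput().is_video)  # → False
print(SnapchatInput().is_video)      # → True
```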
@@ -62,15 +74,10 @@
         self,
         input_data: "PostToSnapchatBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Snapchat with Snapchat-specific options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -121,7 +128,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             snapchat_options=snapchat_options if snapchat_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToTelegramBlock(Block):
     """Block for posting to Telegram with Telegram-specific options."""
 
@@ -57,15 +61,10 @@
         self,
         input_data: "PostToTelegramBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Telegram with Telegram-specific validation."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -108,7 +107,7 @@
             random_post=input_data.random_post,
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToThreadsBlock(Block):
     """Block for posting to Threads with Threads-specific options."""
 
@@ -50,15 +54,10 @@
         self,
         input_data: "PostToThreadsBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to Threads with Threads-specific validation."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -103,7 +102,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             threads_options=threads_options if threads_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -2,15 +2,18 @@ from enum import Enum
 
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
 class TikTokVisibility(str, Enum):
@@ -19,6 +22,7 @@ class TikTokVisibility(str, Enum):
     FOLLOWERS = "followers"
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToTikTokBlock(Block):
     """Block for posting to TikTok with TikTok-specific options."""
 
@@ -113,14 +117,13 @@
     )
 
     async def run(
-        self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
+        self,
+        input_data: "PostToTikTokBlock.Input",
+        *,
+        credentials: APIKeyCredentials,
+        **kwargs,
     ) -> BlockOutput:
         """Post to TikTok with TikTok-specific validation and options."""
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -235,7 +238,7 @@
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             tiktok_options=tiktok_options if tiktok_options else None,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -1,16 +1,20 @@
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToXBlock(Block):
     """Block for posting to X / Twitter with Twitter-specific options."""
 
@@ -115,15 +119,10 @@ class PostToXBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: "PostToXBlock.Input",
|
input_data: "PostToXBlock.Input",
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
credentials: APIKeyCredentials,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
"""Post to X / Twitter with enhanced X-specific options."""
|
"""Post to X / Twitter with enhanced X-specific options."""
|
||||||
profile_key = await get_profile_key(user_id)
|
|
||||||
if not profile_key:
|
|
||||||
yield "error", "Please link a social account via Ayrshare"
|
|
||||||
return
|
|
||||||
|
|
||||||
client = create_ayrshare_client()
|
client = create_ayrshare_client()
|
||||||
if not client:
|
if not client:
|
||||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||||
@@ -233,7 +232,7 @@ class PostToXBlock(Block):
|
|||||||
random_media_url=input_data.random_media_url,
|
random_media_url=input_data.random_media_url,
|
||||||
notes=input_data.notes,
|
notes=input_data.notes,
|
||||||
twitter_options=twitter_options if twitter_options else None,
|
twitter_options=twitter_options if twitter_options else None,
|
||||||
profile_key=profile_key.get_secret_value(),
|
profile_key=credentials.api_key.get_secret_value(),
|
||||||
)
|
)
|
||||||
yield "post_result", response
|
yield "post_result", response
|
||||||
if response.postIds:
|
if response.postIds:
|
||||||
|
|||||||
@@ -3,15 +3,18 @@ from typing import Any
 
 from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
 from backend.sdk import (
+    APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaOutput,
     BlockType,
     SchemaField,
+    cost,
 )
 
-from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
+from ._cost import AYRSHARE_POST_COSTS
+from ._util import BaseAyrshareInput, create_ayrshare_client
 
 
 class YouTubeVisibility(str, Enum):
@@ -20,6 +23,7 @@ class YouTubeVisibility(str, Enum):
     UNLISTED = "unlisted"
 
 
+@cost(*AYRSHARE_POST_COSTS)
 class PostToYouTubeBlock(Block):
     """Block for posting to YouTube with YouTube-specific options."""
 
@@ -39,6 +43,14 @@ class PostToYouTubeBlock(Block):
         advanced=False,
     )
 
+    # YouTube is video-only; override the base default so the @cost filter
+    # selects the 5-credit video tier instead of the 2-credit image tier.
+    is_video: bool = SchemaField(
+        description="Whether the media is a video (always True for YouTube)",
+        default=True,
+        advanced=True,
+    )
+
     # YouTube-specific required options
     title: str = SchemaField(
         description="Video title (max 100 chars, required). Cannot contain < or > characters.",
@@ -137,16 +149,10 @@ class PostToYouTubeBlock(Block):
         self,
         input_data: "PostToYouTubeBlock.Input",
         *,
-        user_id: str,
+        credentials: APIKeyCredentials,
         **kwargs,
     ) -> BlockOutput:
         """Post to YouTube with YouTube-specific validation and options."""
-
-        profile_key = await get_profile_key(user_id)
-        if not profile_key:
-            yield "error", "Please link a social account via Ayrshare"
-            return
-
         client = create_ayrshare_client()
         if not client:
             yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -302,7 +308,7 @@ class PostToYouTubeBlock(Block):
             random_media_url=input_data.random_media_url,
             notes=input_data.notes,
             youtube_options=youtube_options,
-            profile_key=profile_key.get_secret_value(),
+            profile_key=credentials.api_key.get_secret_value(),
         )
         yield "post_result", response
         if response.postIds:
@@ -7,6 +7,7 @@ from backend.sdk import BlockCostType, ProviderBuilder
 # Configure the Meeting BaaS provider with API key authentication
 baas = (
     ProviderBuilder("baas")
+    .with_description("Meeting recording and transcription")
     .with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
     .with_base_cost(5, BlockCostType.RUN)  # Higher cost for meeting recording service
     .build()
@@ -4,21 +4,34 @@ Meeting BaaS bot (recording) blocks.
 
 from typing import Optional
 
+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
     BlockCategory,
+    BlockCost,
+    BlockCostType,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
     CredentialsMetaInput,
     SchemaField,
+    cost,
 )
 
 from ._api import MeetingBaasAPI
 from ._config import baas
 
+# Meeting BaaS recording rate: $0.69 per hour.
+_MEETING_BAAS_USD_PER_SECOND = 0.69 / 3600
+
+# Join bills a flat 30 cr commit (covers median short meeting);
+# FetchMeetingData bills the duration-scaled remainder from the
+# `duration_seconds` field on the API response. Long meetings no
+# longer under-bill.
+
 
+@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=30))
 class BaasBotJoinMeetingBlock(Block):
     """
     Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
@@ -134,6 +147,7 @@ class BaasBotLeaveMeetingBlock(Block):
         yield "left", left
 
 
+@cost(BlockCost(cost_type=BlockCostType.COST_USD, cost_amount=150))
 class BaasBotFetchMeetingDataBlock(Block):
     """
     Pull MP4 URL, transcript & metadata for a completed meeting.
@@ -176,9 +190,21 @@ class BaasBotFetchMeetingDataBlock(Block):
             include_transcripts=input_data.include_transcripts,
         )
 
+        bot_meta = data.get("bot_data", {}).get("bot", {}) or {}
+        # Bill recording duration via COST_USD so multi-hour meetings
+        # scale past the Join block's flat 30 cr deposit.
+        duration_seconds = float(bot_meta.get("duration_seconds") or 0)
+        if duration_seconds > 0:
+            self.merge_stats(
+                NodeExecutionStats(
+                    provider_cost=duration_seconds * _MEETING_BAAS_USD_PER_SECOND,
+                    provider_cost_type="cost_usd",
+                )
+            )
+
         yield "mp4_url", data.get("mp4", "")
         yield "transcript", data.get("bot_data", {}).get("transcripts", [])
-        yield "metadata", data.get("bot_data", {}).get("bot", {})
+        yield "metadata", bot_meta
 
 
 class BaasBotDeleteRecordingBlock(Block):
@@ -0,0 +1,86 @@
+"""Unit tests for Meeting BaaS duration-based cost emission."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from pydantic import SecretStr
+
+from backend.blocks.baas.bots import (
+    _MEETING_BAAS_USD_PER_SECOND,
+    BaasBotFetchMeetingDataBlock,
+)
+from backend.data.model import APIKeyCredentials, NodeExecutionStats
+
+TEST_CREDENTIALS = APIKeyCredentials(
+    id="01234567-89ab-cdef-0123-456789abcdef",
+    provider="baas",
+    title="Mock BaaS API Key",
+    api_key=SecretStr("mock-baas-api-key"),
+    expires_at=None,
+)
+
+
+def test_usd_per_second_derives_from_published_rate():
+    """$0.69/hour published rate → ~$0.000192/second."""
+    assert _MEETING_BAAS_USD_PER_SECOND == pytest.approx(0.69 / 3600)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "duration_seconds, expected_usd",
+    [
+        (3600, 0.69),  # 1 hour
+        (1800, 0.345),  # 30 min
+        (0, None),  # no recording → no emission
+        (None, None),  # missing duration field → no emission
+    ],
+)
+async def test_fetch_meeting_data_emits_duration_cost_usd(
+    duration_seconds, expected_usd
+):
+    """FetchMeetingData extracts duration_seconds from bot metadata and
+    emits provider_cost / cost_usd scaled by the published $0.69/hr rate.
+    Emission is skipped when duration is 0 or missing.
+    """
+    block = BaasBotFetchMeetingDataBlock()
+
+    bot_meta = {"id": "bot-xyz"}
+    if duration_seconds is not None:
+        bot_meta["duration_seconds"] = duration_seconds
+
+    mock_api = AsyncMock()
+    mock_api.get_meeting_data.return_value = {
+        "mp4": "https://example/recording.mp4",
+        "bot_data": {"bot": bot_meta, "transcripts": []},
+    }
+
+    captured: list[NodeExecutionStats] = []
+    with (
+        patch("backend.blocks.baas.bots.MeetingBaasAPI", return_value=mock_api),
+        patch.object(block, "merge_stats", side_effect=captured.append),
+    ):
+        outputs = []
+        async for name, val in block.run(
+            block.input_schema(
+                credentials={
+                    "id": TEST_CREDENTIALS.id,
+                    "provider": TEST_CREDENTIALS.provider,
+                    "type": TEST_CREDENTIALS.type,
+                },
+                bot_id="bot-xyz",
+                include_transcripts=False,
+            ),
+            credentials=TEST_CREDENTIALS,
+        ):
+            outputs.append((name, val))
+
+    # Always yields the 3 outputs regardless of duration.
+    names = [n for n, _ in outputs]
+    assert "mp4_url" in names and "metadata" in names
+
+    if expected_usd is None:
+        assert captured == []
+    else:
+        assert len(captured) == 1
+        assert captured[0].provider_cost == pytest.approx(expected_usd)
+        assert captured[0].provider_cost_type == "cost_usd"
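Reviewer note: the duration-to-USD conversion these tests assert is simple enough to sketch standalone. A minimal sketch, assuming only the published $0.69/hour rate from the diff above; `duration_cost_usd` is an illustrative name, not the block's API.

```python
# Hypothetical standalone sketch of the Meeting BaaS duration billing math.
USD_PER_SECOND = 0.69 / 3600  # published rate: $0.69 per recorded hour


def duration_cost_usd(duration_seconds: float) -> float:
    # No recording (zero or negative duration) costs nothing.
    return max(duration_seconds, 0.0) * USD_PER_SECOND


assert abs(duration_cost_usd(3600) - 0.69) < 1e-9   # one hour
assert abs(duration_cost_usd(1800) - 0.345) < 1e-9  # thirty minutes
assert duration_cost_usd(0) == 0.0                  # no recording, no charge
```

This mirrors why the test table pairs 3600 s with $0.69 and 1800 s with $0.345, and why 0/None durations produce no emission at all.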
@@ -2,7 +2,8 @@ from backend.sdk import BlockCostType, ProviderBuilder
 
 bannerbear = (
     ProviderBuilder("bannerbear")
+    .with_description("Auto-generate images and videos")
     .with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
-    .with_base_cost(1, BlockCostType.RUN)
+    .with_base_cost(3, BlockCostType.RUN)
     .build()
 )
@@ -433,7 +433,7 @@ class TestJinaEmbeddingBlockCostTracking:
 class TestUnrealTextToSpeechBlockCostTracking:
     @pytest.mark.asyncio
     async def test_merge_stats_called_with_character_count(self):
-        """provider_cost equals len(text) with type='characters'."""
+        """provider_cost = len(text) * $0.000016 with type='cost_usd'."""
         from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
         from backend.blocks.text_to_speech_block import (
             TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
@@ -461,12 +461,12 @@ class TestUnrealTextToSpeechBlockCostTracking:
 
         mock_merge.assert_called_once()
         stats = mock_merge.call_args[0][0]
-        assert stats.provider_cost == float(len(test_text))
-        assert stats.provider_cost_type == "characters"
+        assert stats.provider_cost == pytest.approx(len(test_text) * 0.000016)
+        assert stats.provider_cost_type == "cost_usd"
 
     @pytest.mark.asyncio
     async def test_empty_text_gives_zero_characters(self):
-        """An empty text string results in provider_cost=0.0."""
+        """An empty text string results in provider_cost=0.0 (cost_usd)."""
         from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
         from backend.blocks.text_to_speech_block import (
             TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
@@ -494,7 +494,7 @@ class TestUnrealTextToSpeechBlockCostTracking:
         mock_merge.assert_called_once()
         stats = mock_merge.call_args[0][0]
         assert stats.provider_cost == 0.0
-        assert stats.provider_cost_type == "characters"
+        assert stats.provider_cost_type == "cost_usd"
 
 
 # ---------------------------------------------------------------------------
@@ -17,6 +17,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -431,6 +432,7 @@ class ClaudeCodeBlock(Block):
             # The JSON output contains the result
             output_data = json.loads(raw_output)
             response = output_data.get("result", raw_output)
+            self._record_cli_cost(output_data)
 
             # Build conversation history entry
             turn_entry = f"User: {prompt}\nClaude: {response}"
@@ -484,6 +486,23 @@ class ClaudeCodeBlock(Block):
         escaped = prompt.replace("'", "'\"'\"'")
         return f"'{escaped}'"
 
+    def _record_cli_cost(self, output_data: dict) -> None:
+        """Feed Claude Code CLI's `total_cost_usd` to the COST_USD resolver.
+
+        The CLI rolls up Anthropic LLM + internal tool-call spend into
+        ``total_cost_usd`` on its JSON response; piping it through
+        ``merge_stats`` lets the wallet reflect real spend.
+        """
+        total_cost_usd = output_data.get("total_cost_usd")
+        if total_cost_usd is None:
+            return
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=float(total_cost_usd),
+                provider_cost_type="cost_usd",
+            )
+        )
+
     async def run(
         self,
         input_data: Input,
autogpt_platform/backend/backend/blocks/claude_code_cost_test.py  (new file, 106 lines)
@@ -0,0 +1,106 @@
+"""Unit tests for ClaudeCodeBlock COST_USD billing migration.
+
+Verifies:
+- Block emits provider_cost / cost_usd when Claude Code CLI returns
+  total_cost_usd.
+- block_usage_cost resolves the COST_USD entry to the expected ceil(usd *
+  cost_amount) credit charge.
+- Missing total_cost_usd gracefully produces provider_cost=None (no bill).
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from backend.blocks._base import BlockCostType
+from backend.blocks.claude_code import ClaudeCodeBlock
+from backend.data.block_cost_config import BLOCK_COSTS
+from backend.data.model import NodeExecutionStats
+from backend.executor.utils import block_usage_cost
+
+
+def test_claude_code_registered_as_cost_usd_150():
+    """Sanity: BLOCK_COSTS holds the COST_USD, 150 cr/$ entry."""
+    entries = BLOCK_COSTS[ClaudeCodeBlock]
+    assert len(entries) == 1
+    entry = entries[0]
+    assert entry.cost_type == BlockCostType.COST_USD
+    assert entry.cost_amount == 150
+
+
+@pytest.mark.parametrize(
+    "total_cost_usd, expected_credits",
+    [
+        (0.50, 75),  # $0.50 × 150 = 75 cr
+        (1.00, 150),  # $1.00 × 150 = 150 cr
+        (0.0134, 3),  # ceil(0.0134 × 150) = ceil(2.01) = 3
+        (2.00, 300),  # $2 × 150 = 300 cr
+        (0.001, 1),  # ceil(0.001 × 150) = ceil(0.15) = 1 — no 0-cr leak on
+        # sub-cent runs
+    ],
+)
+def test_cost_usd_resolver_applies_150_multiplier(total_cost_usd, expected_credits):
+    """block_usage_cost with cost_usd stats returns ceil(usd * 150)."""
+    block = ClaudeCodeBlock()
+    # cost_filter requires matching e2b_credentials; supply the ones the
+    # registration uses so _is_cost_filter_match accepts the input.
+    entry = BLOCK_COSTS[ClaudeCodeBlock][0]
+    input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
+    stats = NodeExecutionStats(
+        provider_cost=total_cost_usd,
+        provider_cost_type="cost_usd",
+    )
+    cost, matching_filter = block_usage_cost(
+        block=block, input_data=input_data, stats=stats
+    )
+    assert cost == expected_credits
+    assert matching_filter == entry.cost_filter
+
+
+def test_cost_usd_resolver_returns_zero_when_stats_missing_cost():
+    """Pre-flight (no stats) or unbilled run (provider_cost None) → 0."""
+    block = ClaudeCodeBlock()
+    entry = BLOCK_COSTS[ClaudeCodeBlock][0]
+    input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
+    # No stats at all → pre-flight path, returns 0.
+    pre_cost, _ = block_usage_cost(block=block, input_data=input_data)
+    assert pre_cost == 0
+    # Stats present but no provider_cost → resolver can't bill.
+    stats = NodeExecutionStats()
+    post_cost, _ = block_usage_cost(block=block, input_data=input_data, stats=stats)
+    assert post_cost == 0
+
+
+def test_record_cli_cost_emits_provider_cost_when_total_cost_present():
+    """``_record_cli_cost`` (the helper called from ``execute_claude_code``)
+    must emit a single ``merge_stats`` with provider_cost + cost_usd tag
+    when the CLI JSON payload carries ``total_cost_usd``.
+    """
+    block = ClaudeCodeBlock()
+    captured: list[NodeExecutionStats] = []
+    with patch.object(block, "merge_stats", side_effect=captured.append):
+        block._record_cli_cost(
+            {
+                "result": "hello from claude",
+                "total_cost_usd": 0.0421,
+                "usage": {"input_tokens": 1234, "output_tokens": 56},
+            }
+        )
+
+    assert len(captured) == 1
+    stats = captured[0]
+    assert stats.provider_cost == pytest.approx(0.0421)
+    assert stats.provider_cost_type == "cost_usd"
+
+
+def test_record_cli_cost_skips_merge_when_total_cost_absent():
+    """If the CLI payload lacks ``total_cost_usd`` (legacy / non-JSON
+    output), ``_record_cli_cost`` must not call ``merge_stats`` — otherwise
+    we'd pollute telemetry with a ``cost_usd`` emission that has no real
+    cost attached.
+    """
+    block = ClaudeCodeBlock()
+    mock = MagicMock()
+    with patch.object(block, "merge_stats", mock):
+        block._record_cli_cost({"result": "hello"})
+    mock.assert_not_called()
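Reviewer note: the USD-to-credits conversion asserted in the parametrized cases above can be sketched in isolation. A minimal sketch, assuming the resolver multiplies provider USD by the entry's `cost_amount` (150 here) and rounds up; `usd_to_credits` is an illustrative name, not the platform API.

```python
import math

CREDITS_PER_USD = 150  # cost_amount on the COST_USD registration


def usd_to_credits(usd: float) -> int:
    # Round up so sub-cent provider spend still bills at least 1 credit.
    return math.ceil(usd * CREDITS_PER_USD)


assert usd_to_credits(0.50) == 75     # $0.50 × 150
assert usd_to_credits(0.0134) == 3    # ceil(2.01)
assert usd_to_credits(0.001) == 1     # ceil(0.15), no zero-credit leak
```

The ceiling is the design choice under test: a floor (or int truncation) would let every run under ~$0.0067 bill zero credits.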
@@ -151,6 +151,17 @@ class CodeGenerationBlock(Block):
         )
         self.execution_stats = NodeExecutionStats()
 
+    # GPT-5.1-Codex published pricing: $1.25 / 1M input, $10 / 1M output.
+    _INPUT_USD_PER_1M = 1.25
+    _OUTPUT_USD_PER_1M = 10.0
+
+    @staticmethod
+    def _compute_token_usd(input_tokens: int, output_tokens: int) -> float:
+        return (
+            input_tokens * CodeGenerationBlock._INPUT_USD_PER_1M
+            + output_tokens * CodeGenerationBlock._OUTPUT_USD_PER_1M
+        ) / 1_000_000
+
     async def call_codex(
         self,
         *,
@@ -189,13 +200,15 @@ class CodeGenerationBlock(Block):
         response_id = response.id or ""
 
         # Update usage stats
-        self.execution_stats.input_token_count = (
-            response.usage.input_tokens if response.usage else 0
-        )
-        self.execution_stats.output_token_count = (
-            response.usage.output_tokens if response.usage else 0
-        )
+        input_tokens = response.usage.input_tokens if response.usage else 0
+        output_tokens = response.usage.output_tokens if response.usage else 0
+        self.execution_stats.input_token_count = input_tokens
+        self.execution_stats.output_token_count = output_tokens
         self.execution_stats.llm_call_count += 1
+        self.execution_stats.provider_cost = self._compute_token_usd(
+            input_tokens, output_tokens
+        )
+        self.execution_stats.provider_cost_type = "cost_usd"
 
         return CodexCallResult(
             response=text_output,
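Reviewer note: the per-token arithmetic behind `_compute_token_usd` is worth one worked example. Rates copied from the diff's pricing comment; this is a standalone sketch, not the block's actual code path.

```python
# Rates from the GPT-5.1-Codex pricing comment in the diff above.
INPUT_USD_PER_1M = 1.25
OUTPUT_USD_PER_1M = 10.0


def token_cost_usd(input_tokens: int, output_tokens: int) -> float:
    # Linear blend of the two published per-million-token rates.
    return (
        input_tokens * INPUT_USD_PER_1M + output_tokens * OUTPUT_USD_PER_1M
    ) / 1_000_000


# 100k input ($0.125) + 10k output ($0.100) totals $0.225,
# matching the (100_000, 10_000, 0.225) row in the test table below.
assert abs(token_cost_usd(100_000, 10_000) - 0.225) < 1e-9
```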
autogpt_platform/backend/backend/blocks/compass/_config.py  (new file, 10 lines)
@@ -0,0 +1,10 @@
+"""Provider registration for Compass — metadata only (auth lives elsewhere)."""
+
+from backend.sdk import ProviderBuilder
+
+compass = (
+    ProviderBuilder("compass")
+    .with_description("Geospatial context for agents")
+    .with_supported_auth_types("api_key")
+    .build()
+)
autogpt_platform/backend/backend/blocks/cost_leak_fixes_test.py  (new file, 226 lines)
@@ -0,0 +1,226 @@
+"""Coverage tests for the cost-leak fixes in this PR.
+
+Each block's ``run()`` / helper emits provider_cost + cost_usd (or items)
+via merge_stats so the post-flight resolver bills real provider spend.
+Tests here drive that emission path directly so a regression on any one
+block surfaces immediately.
+"""
+
+from unittest.mock import patch
+
+import pytest
+from pydantic import SecretStr
+
+from backend.blocks._base import BlockCostType
+from backend.blocks.ai_condition import AIConditionBlock
+from backend.data.block_cost_config import BLOCK_COSTS, LLM_COST
+from backend.data.model import APIKeyCredentials, NodeExecutionStats
+
+# -------- AIConditionBlock registration --------
+
+
+def test_ai_condition_registered_under_llm_cost():
+    """AIConditionBlock was running wallet-free before this PR; verify it
+    now resolves through the same per-model LLM_COST table as every other
+    LLM block.
+    """
+    assert BLOCK_COSTS[AIConditionBlock] is LLM_COST
+
+
+# -------- Pinecone insert ITEMS emission --------
+
+
+@pytest.mark.asyncio
+async def test_pinecone_insert_emits_items_provider_cost():
+    from backend.blocks.pinecone import PineconeInsertBlock
+
+    block = PineconeInsertBlock()
+    captured: list[NodeExecutionStats] = []
+
+    class _FakeIndex:
+        def upsert(self, **_):
+            return None
+
+    class _FakePinecone:
+        def __init__(self, *_, **__):
+            pass
+
+        def Index(self, _name):
+            return _FakeIndex()
+
+    with (
+        patch("backend.blocks.pinecone.Pinecone", _FakePinecone),
+        patch.object(block, "merge_stats", side_effect=captured.append),
+    ):
+        input_data = block.input_schema(
+            credentials={
+                "id": "00000000-0000-0000-0000-000000000000",
+                "provider": "pinecone",
+                "type": "api_key",
+            },
+            index="my-index",
+            chunks=["alpha", "beta", "gamma"],
+            embeddings=[[0.1] * 4, [0.2] * 4, [0.3] * 4],
+            namespace="",
+            metadata={},
+        )
+
+        creds = APIKeyCredentials(
+            id="00000000-0000-0000-0000-000000000000",
+            provider="pinecone",
+            title="mock",
+            api_key=SecretStr("mock-key"),
+            expires_at=None,
+        )
+        outputs = [(n, v) async for n, v in block.run(input_data, credentials=creds)]
+
+    assert any(name == "upsert_response" for name, _ in outputs)
+    assert len(captured) == 1
+    stats = captured[0]
+    assert stats.provider_cost == pytest.approx(3.0)
+    assert stats.provider_cost_type == "items"
+
+
+# -------- Narration model-aware per-char rate --------
+
+
+@pytest.mark.parametrize(
+    "model_id, expected_rate_per_char",
+    [
+        ("eleven_flash_v2_5", 0.000167 * 0.5),
+        ("eleven_turbo_v2_5", 0.000167 * 0.5),
+        ("eleven_multilingual_v2", 0.000167 * 1.0),
+        ("eleven_turbo_v2", 0.000167 * 1.0),
+    ],
+)
+def test_narration_per_char_rate_scales_with_model(model_id, expected_rate_per_char):
+    """Drive VideoNarrationBlock._record_script_cost directly so a regression
+    that drops the model-aware branching (e.g. hardcoding 1.0 cr/char for
+    all models) makes this test fail.
+    """
+    from backend.blocks.video.narration import VideoNarrationBlock
+
+    block = VideoNarrationBlock()
+    captured: list[NodeExecutionStats] = []
+    with patch.object(block, "merge_stats", side_effect=captured.append):
+        block._record_script_cost("x" * 5000, model_id)
+
+    assert len(captured) == 1
+    stats = captured[0]
+    assert stats.provider_cost == pytest.approx(5000 * expected_rate_per_char)
+    assert stats.provider_cost_type == "cost_usd"
+
+
+# -------- Perplexity None-guard on x-total-cost --------
+
+
+@pytest.mark.parametrize(
+    "openrouter_cost, expect_type",
+    [
+        (0.0421, "cost_usd"),  # concrete positive USD → tagged
+        (None, None),  # header missing → no tag (keeps gap observable)
+        (0.0, None),  # zero → no tag (wouldn't bill anything anyway)
+    ],
+)
+def test_perplexity_record_openrouter_cost_tags_only_on_concrete_value(
+    openrouter_cost, expect_type
+):
+    """Drive PerplexityBlock._record_openrouter_cost directly to verify the
+    None/0 guard. A regression that tags cost_usd unconditionally would
+    silently floor the user's bill to 0 via the resolver — this test
+    would catch it.
+    """
+    from backend.blocks.perplexity import PerplexityBlock
+
+    block = PerplexityBlock()
+    with patch(
+        "backend.blocks.perplexity.extract_openrouter_cost",
+        return_value=openrouter_cost,
+    ):
+        block._record_openrouter_cost(response=object())
+
+    assert block.execution_stats.provider_cost == openrouter_cost
+    assert block.execution_stats.provider_cost_type == expect_type
+
+
+# -------- Codex COST_USD registration --------
+
+
+def test_codex_registered_as_cost_usd_150():
+    from backend.blocks.codex import CodeGenerationBlock
+
+    entries = BLOCK_COSTS[CodeGenerationBlock]
+    assert len(entries) == 1
+    entry = entries[0]
+    assert entry.cost_type == BlockCostType.COST_USD
+    assert entry.cost_amount == 150
+
+
+@pytest.mark.parametrize(
+    "input_tokens, output_tokens, expected_usd",
+    [
+        # GPT-5.1-Codex: $1.25 / 1M input, $10 / 1M output.
+        (1_000_000, 0, 1.25),
+        (0, 1_000_000, 10.0),
+        (100_000, 10_000, 0.225),  # 0.125 + 0.100
+        (0, 0, 0.0),
+    ],
+)
+def test_codex_computes_provider_cost_usd_from_token_counts(
+    input_tokens, output_tokens, expected_usd
+):
+    """Drive CodeGenerationBlock._compute_token_usd directly. A regression
+    to the wrong rate constants (e.g. swapping the $1.25 input rate for
+    GPT-4o's $2.50) would fail this test.
+    """
+    from backend.blocks.codex import CodeGenerationBlock
+
+    assert CodeGenerationBlock._compute_token_usd(
+        input_tokens, output_tokens
+    ) == pytest.approx(expected_usd)
+
+
+# -------- ClaudeCode COST_USD registration sanity (already tested in claude_code_cost_test.py) --------
+
+
+# -------- Perplexity COST_USD registration for all 3 tiers --------
+
+
+def test_perplexity_sonar_all_tiers_registered_as_cost_usd_150():
+    from backend.blocks.perplexity import PerplexityBlock
+
+    entries = BLOCK_COSTS[PerplexityBlock]
+    # 3 tiers (SONAR, SONAR_PRO, SONAR_DEEP_RESEARCH) all COST_USD 150.
|
||||||
|
assert len(entries) == 3
|
||||||
|
for entry in entries:
|
||||||
|
assert entry.cost_type == BlockCostType.COST_USD
|
||||||
|
assert entry.cost_amount == 150
|
||||||
|
|
||||||
|
|
||||||
|
# -------- Narration COST_USD registration --------
|
||||||
|
|
||||||
|
|
||||||
|
def test_narration_registered_as_cost_usd_150():
|
||||||
|
from backend.blocks.video.narration import VideoNarrationBlock
|
||||||
|
|
||||||
|
entries = BLOCK_COSTS[VideoNarrationBlock]
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0].cost_type == BlockCostType.COST_USD
|
||||||
|
assert entries[0].cost_amount == 150
|
||||||
|
|
||||||
|
|
||||||
|
# -------- Pinecone registrations --------
|
||||||
|
|
||||||
|
|
||||||
|
def test_pinecone_registrations():
|
||||||
|
from backend.blocks.pinecone import (
|
||||||
|
PineconeInitBlock,
|
||||||
|
PineconeInsertBlock,
|
||||||
|
PineconeQueryBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert BLOCK_COSTS[PineconeInitBlock][0].cost_type == BlockCostType.RUN
|
||||||
|
assert BLOCK_COSTS[PineconeQueryBlock][0].cost_type == BlockCostType.RUN
|
||||||
|
# Insert scales with item count.
|
||||||
|
assert BLOCK_COSTS[PineconeInsertBlock][0].cost_type == BlockCostType.ITEMS
|
||||||
|
assert BLOCK_COSTS[PineconeInsertBlock][0].cost_amount == 1
|
||||||
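The rate constants encoded in the Codex test cases above ($1.25 per 1M input tokens, $10 per 1M output tokens for GPT-5.1-Codex) reduce to two multiplications. A minimal sketch of that arithmetic — `compute_token_usd` below is an illustrative stand-in, not the actual `CodeGenerationBlock._compute_token_usd`:

```python
# Illustrative GPT-5.1-Codex rates taken from the test's parametrized cases:
# $1.25 per 1M input tokens, $10 per 1M output tokens.
INPUT_USD_PER_MILLION = 1.25
OUTPUT_USD_PER_MILLION = 10.0


def compute_token_usd(input_tokens: int, output_tokens: int) -> float:
    """Sketch of token counts -> USD; mirrors the expected_usd column above."""
    return (
        input_tokens / 1_000_000 * INPUT_USD_PER_MILLION
        + output_tokens / 1_000_000 * OUTPUT_USD_PER_MILLION
    )


# 100k input + 10k output: 0.125 + 0.100 = 0.225 USD, matching the test row.
```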
@@ -19,6 +19,10 @@ class DataForSeoClient:
             trusted_origins=["https://api.dataforseo.com"],
             raise_for_status=False,
         )
+        # USD cost reported by DataForSEO on the most recent successful call.
+        # Populated by keyword_suggestions / related_keywords so the caller
+        # can surface it via NodeExecutionStats.provider_cost for billing.
+        self.last_cost_usd: float = 0.0

     def _get_headers(self) -> Dict[str, str]:
         """Generate the authorization header using Basic Auth."""
@@ -97,6 +101,9 @@ class DataForSeoClient:
         if data.get("tasks") and len(data["tasks"]) > 0:
             task = data["tasks"][0]
             if task.get("status_code") == 20000:  # Success code
+                # DataForSEO reports per-task USD cost; stash it so callers
+                # can populate NodeExecutionStats.provider_cost.
+                self.last_cost_usd = float(task.get("cost") or 0.0)
                 return task.get("result", [])
             else:
                 error_msg = task.get("status_message", "Task failed")
@@ -174,6 +181,9 @@ class DataForSeoClient:
         if data.get("tasks") and len(data["tasks"]) > 0:
             task = data["tasks"][0]
             if task.get("status_code") == 20000:  # Success code
+                # DataForSEO reports per-task USD cost; stash it so callers
+                # can populate NodeExecutionStats.provider_cost.
+                self.last_cost_usd = float(task.get("cost") or 0.0)
                 return task.get("result", [])
             else:
                 error_msg = task.get("status_message", "Task failed")
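The stash-then-emit pattern these hunks set up — the client records the per-task USD cost, and the calling block later forwards it to billing via `merge_stats` — can be sketched in isolation. `FakeClient` is an illustrative stand-in for `DataForSeoClient`, not code from the repo:

```python
class FakeClient:
    """Stand-in mimicking DataForSeoClient's cost stashing (illustrative)."""

    def __init__(self) -> None:
        # USD cost reported on the most recent successful call.
        self.last_cost_usd: float = 0.0

    def handle_task(self, task: dict) -> list:
        if task.get("status_code") == 20000:  # DataForSEO success code
            # Stash the per-task USD cost so the caller can bill it later.
            self.last_cost_usd = float(task.get("cost") or 0.0)
            return task.get("result", [])
        raise RuntimeError(task.get("status_message", "Task failed"))


client = FakeClient()
results = client.handle_task({"status_code": 20000, "cost": 0.012, "result": ["kw"]})
```

After the call, the block would read `client.last_cost_usd` and merge it into its execution stats.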
@@ -7,11 +7,17 @@ from backend.sdk import BlockCostType, ProviderBuilder
 # Build the DataForSEO provider with username/password authentication
 dataforseo = (
     ProviderBuilder("dataforseo")
+    .with_description("SEO and SERP data")
     .with_user_password(
         username_env_var="DATAFORSEO_USERNAME",
         password_env_var="DATAFORSEO_PASSWORD",
         title="DataForSEO Credentials",
     )
-    .with_base_cost(1, BlockCostType.RUN)
+    # DataForSEO reports USD cost per task (e.g. $0.001/keyword returned).
+    # DataForSeoClient stashes it on last_cost_usd; each block emits it via
+    # merge_stats so the COST_USD resolver bills against real spend.
+    # 1000 platform credits per USD → 1 credit per $0.001 (≈ 1 credit/
+    # returned keyword on the standard tier).
+    .with_base_cost(1000, BlockCostType.COST_USD)
     .build()
 )
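The credit math in the comment above (1000 credits per USD, and DataForSEO's cited ~$0.001 per returned keyword on the standard tier) can be sanity-checked with quick arithmetic. `task_credits` is a hypothetical helper written for illustration; the platform's real resolver is not shown in this diff:

```python
DATAFORSEO_CREDITS_PER_USD = 1000  # from .with_base_cost(1000, BlockCostType.COST_USD)
USD_PER_KEYWORD = 0.001  # standard-tier rate cited in the config comment


def task_credits(keywords_returned: int) -> int:
    """Hypothetical: credits billed for one standard-tier task."""
    return round(keywords_returned * USD_PER_KEYWORD * DATAFORSEO_CREDITS_PER_USD)


# 25 returned keywords ≈ $0.025 of provider spend ≈ 25 credits,
# i.e. roughly 1 credit per returned keyword as the comment states.
```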
@@ -4,6 +4,7 @@ DataForSEO Google Keyword Suggestions block.

 from typing import Any, Dict, List, Optional

+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     Block,
     BlockCategory,
@@ -110,8 +111,10 @@ class DataForSeoKeywordSuggestionsBlock(Block):
         test_output=[
             (
                 "suggestion",
-                lambda x: hasattr(x, "keyword")
-                and x.keyword == "digital marketing strategy",
+                lambda x: (
+                    hasattr(x, "keyword")
+                    and x.keyword == "digital marketing strategy"
+                ),
             ),
             ("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
             ("total_count", 1),
@@ -167,6 +170,16 @@ class DataForSeoKeywordSuggestionsBlock(Block):

         results = await self._fetch_keyword_suggestions(client, input_data)

+        # DataForSEO reports per-task USD cost on the response. Feed it
+        # into NodeExecutionStats so the COST_USD resolver bills the
+        # real provider spend at reconciliation time.
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=client.last_cost_usd,
+                provider_cost_type="cost_usd",
+            )
+        )
+
         # Process and format the results
         suggestions = []
         if results and len(results) > 0:
@@ -4,6 +4,7 @@ DataForSEO Google Related Keywords block.

 from typing import Any, Dict, List, Optional

+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     Block,
     BlockCategory,
@@ -177,6 +178,16 @@ class DataForSeoRelatedKeywordsBlock(Block):

         results = await self._fetch_related_keywords(client, input_data)

+        # DataForSEO reports per-task USD cost on the response. Feed it
+        # into NodeExecutionStats so the COST_USD resolver bills the
+        # real provider spend at reconciliation time.
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=client.last_cost_usd,
+                provider_cost_type="cost_usd",
+            )
+        )
+
         # Process and format the results
         related_keywords = []
         if results and len(results) > 0:
autogpt_platform/backend/backend/blocks/discord/_config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+"""Provider registration for Discord — metadata only (auth lives in ``_auth.py``)."""
+
+from backend.sdk import ProviderBuilder
+
+discord = (
+    ProviderBuilder("discord")
+    .with_description("Messages, channels, and servers")
+    .with_supported_auth_types("api_key", "oauth2")
+    .build()
+)
@@ -0,0 +1,10 @@
+"""Provider registration for ElevenLabs — metadata only (auth lives in ``_auth.py``)."""
+
+from backend.sdk import ProviderBuilder
+
+elevenlabs = (
+    ProviderBuilder("elevenlabs")
+    .with_description("Realistic AI voice synthesis")
+    .with_supported_auth_types("api_key")
+    .build()
+)
@@ -0,0 +1,10 @@
+"""Provider registration for Enrichlayer — metadata only (auth lives in ``_auth.py``)."""
+
+from backend.sdk import ProviderBuilder
+
+enrichlayer = (
+    ProviderBuilder("enrichlayer")
+    .with_description("Enrich leads with company data")
+    .with_supported_auth_types("api_key")
+    .build()
+)
@@ -9,8 +9,14 @@ from ._webhook import ExaWebhookManager
 # Configure the Exa provider once for all blocks
 exa = (
     ProviderBuilder("exa")
+    .with_description("Neural web search")
     .with_api_key("EXA_API_KEY", "Exa API Key")
     .with_webhook_manager(ExaWebhookManager)
-    .with_base_cost(1, BlockCostType.RUN)
+    # Exa returns `cost_dollars.total` on every response and ExaSearchBlock
+    # (plus ~45 sibling blocks that share this provider config) already
+    # populates NodeExecutionStats.provider_cost with it. Bill 100 credits
+    # per USD (~$0.01/credit): cheap searches stay at 1–2 credits, a Deep
+    # Research run at $0.20 lands at 20 credits, matching provider spend.
+    .with_base_cost(100, BlockCostType.COST_USD)
     .build()
 )
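The 100-credits-per-USD multiplier in the Exa config comment above can likewise be checked with a quick sketch. This assumes the resolver simply multiplies reported USD spend by the per-USD rate and rounds; `usd_to_credits` is a hypothetical helper, and the platform's actual rounding rule is not shown in this diff:

```python
EXA_CREDITS_PER_USD = 100  # from .with_base_cost(100, BlockCostType.COST_USD)


def usd_to_credits(cost_usd: float) -> int:
    """Hypothetical conversion of reported Exa USD spend to platform credits."""
    return round(cost_usd * EXA_CREDITS_PER_USD)


# A $0.02 search lands at 2 credits; a $0.20 Deep Research run lands at 20,
# consistent with the comment's "1–2 credits" / "20 credits" examples.
```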
@@ -17,6 +17,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost


 class AnswerCitation(BaseModel):
@@ -111,3 +112,7 @@ class ExaAnswerBlock(Block):
         yield "citations", citations
         for citation in citations:
             yield "citation", citation
+
+        # Current SDK AnswerResponse dataclass omits cost_dollars; helper
+        # no-ops today, but keeps billing wired when exa_py adds the field.
+        merge_exa_cost(self, response)
@@ -9,7 +9,6 @@ from typing import Union

 from pydantic import BaseModel

-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -23,6 +22,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost


 class CodeContextResponse(BaseModel):
@@ -118,9 +118,5 @@ class ExaCodeContextBlock(Block):
         yield "search_time", context.search_time
         yield "output_tokens", context.output_tokens

-        # Parse cost_dollars (API returns as string, e.g. "0.005")
-        try:
-            cost_usd = float(context.cost_dollars)
-            self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
-        except (ValueError, TypeError):
-            pass
+        # API returns costDollars as a bare numeric string like "0.005".
+        merge_exa_cost(self, data)
@@ -4,7 +4,6 @@ from typing import Optional
 from exa_py import AsyncExa
 from pydantic import BaseModel

-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -24,6 +23,7 @@ from .helpers import (
     HighlightSettings,
     LivecrawlTypes,
     SummarySettings,
+    merge_exa_cost,
 )


@@ -224,6 +224,4 @@ class ExaContentsBlock(Block):

         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
-            self.merge_stats(
-                NodeExecutionStats(provider_cost=response.cost_dollars.total)
-            )
+            merge_exa_cost(self, response)
@@ -143,7 +143,9 @@ class TestExaContentsCostTracking:
         mock_exa_cls.return_value = mock_exa

         async for _ in block.run(
-            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -172,7 +174,9 @@ class TestExaContentsCostTracking:
         mock_exa_cls.return_value = mock_exa

         async for _ in block.run(
-            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -201,7 +205,9 @@ class TestExaContentsCostTracking:
         mock_exa_cls.return_value = mock_exa

         async for _ in block.run(
-            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -297,7 +303,9 @@ class TestExaSimilarCostTracking:
         mock_exa_cls.return_value = mock_exa

         async for _ in block.run(
-            block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -326,7 +334,9 @@ class TestExaSimilarCostTracking:
         mock_exa_cls.return_value = mock_exa

         async for _ in block.run(
-            block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
+            block.Input(
+                url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
+            ),  # type: ignore[arg-type]
             credentials=TEST_CREDENTIALS,
         ):
             pass
@@ -1,7 +1,8 @@
 from enum import Enum
 from typing import Any, Dict, Literal, Optional, Union

-from backend.sdk import BaseModel, MediaFileType, SchemaField
+from backend.data.model import NodeExecutionStats
+from backend.sdk import BaseModel, Block, MediaFileType, SchemaField


 class LivecrawlTypes(str, Enum):
@@ -319,7 +320,7 @@ class CostDollars(BaseModel):

 # Helper functions for payload processing
 def process_text_field(
-    text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
+    text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
 ) -> Optional[Union[bool, Dict[str, Any]]]:
     """Process text field for API payload."""
     if text is None:
@@ -400,7 +401,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,


 def process_context_field(
-    context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
+    context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
 ) -> Optional[Union[bool, Dict[str, int]]]:
     """Process context field for API payload."""
     if context is None:
@@ -448,3 +449,65 @@ def add_optional_fields(
         payload[api_field] = value.value
     else:
         payload[api_field] = value
+
+
+def extract_exa_cost_usd(response: Any) -> Optional[float]:
+    """Return ``cost_dollars.total`` (USD) from an Exa SDK response, or None.
+
+    Handles dataclass/pydantic responses (``response.cost_dollars.total``),
+    dicts with camelCase keys (``response["costDollars"]["total"]``), dicts
+    with snake_case keys, and bare numeric strings. Returns None whenever the
+    shape is missing cost info — the caller then skips merge_stats.
+    """
+    if response is None:
+        return None
+
+    # Dataclass / pydantic: response.cost_dollars
+    cost_obj = getattr(response, "cost_dollars", None)
+
+    # Dict payloads: try both camelCase and snake_case
+    if cost_obj is None and isinstance(response, dict):
+        cost_obj = response.get("costDollars") or response.get("cost_dollars")
+
+    if cost_obj is None:
+        return None
+
+    # Already a scalar (code_context endpoint returns a string)
+    if isinstance(cost_obj, (int, float)):
+        return max(0.0, float(cost_obj))
+    if isinstance(cost_obj, str):
+        try:
+            return max(0.0, float(cost_obj))
+        except ValueError:
+            return None
+
+    # Nested object/dict: grab the `total` field
+    total = getattr(cost_obj, "total", None)
+    if total is None and isinstance(cost_obj, dict):
+        total = cost_obj.get("total")
+
+    if total is None:
+        return None
+
+    try:
+        return max(0.0, float(total))
+    except (TypeError, ValueError):
+        return None
+
+
+def merge_exa_cost(block: Block, response: Any) -> None:
+    """Pull ``cost_dollars.total`` off an Exa response and merge it into stats.
+
+    No-op when the response shape has no cost info (e.g. webset CRUD where
+    the SDK does not expose per-call pricing) — emission happens only when
+    Exa actually reports a USD amount.
+    """
+    cost_usd = extract_exa_cost_usd(response)
+    if cost_usd is None:
+        return
+    block.merge_stats(
+        NodeExecutionStats(
+            provider_cost=cost_usd,
+            provider_cost_type="cost_usd",
+        )
+    )
@@ -0,0 +1,65 @@
+"""Unit tests for exa/helpers cost-extraction + merge helpers."""
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+from backend.blocks.exa.helpers import extract_exa_cost_usd, merge_exa_cost
+from backend.data.model import NodeExecutionStats
+
+
+@pytest.mark.parametrize(
+    "response, expected",
+    [
+        # Dataclass / SimpleNamespace with cost_dollars.total
+        (SimpleNamespace(cost_dollars=SimpleNamespace(total=0.05)), 0.05),
+        # Dict camelCase
+        ({"costDollars": {"total": 0.10}}, 0.10),
+        # Dict snake_case
+        ({"cost_dollars": {"total": 0.07}}, 0.07),
+        # code_context endpoint shape: plain numeric string
+        (SimpleNamespace(cost_dollars="0.005"), 0.005),
+        # Scalar float on cost_dollars directly
+        (SimpleNamespace(cost_dollars=0.02), 0.02),
+        # Scalar int on cost_dollars
+        (SimpleNamespace(cost_dollars=3), 3.0),
+        # Missing cost info — returns None
+        ({}, None),
+        (SimpleNamespace(other="foo"), None),
+        (None, None),
+        # Nested total=None
+        (SimpleNamespace(cost_dollars=SimpleNamespace(total=None)), None),
+        # Invalid numeric string
+        (SimpleNamespace(cost_dollars="not-a-number"), None),
+        # Negative values clamp to 0
+        (SimpleNamespace(cost_dollars=SimpleNamespace(total=-1.0)), 0.0),
+    ],
+)
+def test_extract_exa_cost_usd_handles_all_shapes(response, expected):
+    assert extract_exa_cost_usd(response) == expected
+
+
+def test_merge_exa_cost_emits_stats_when_cost_present():
+    block = MagicMock()
+    response = SimpleNamespace(cost_dollars=SimpleNamespace(total=0.0421))
+    merge_exa_cost(block, response)
+
+    block.merge_stats.assert_called_once()
+    stats: NodeExecutionStats = block.merge_stats.call_args.args[0]
+    assert stats.provider_cost == pytest.approx(0.0421)
+    assert stats.provider_cost_type == "cost_usd"
+
+
+def test_merge_exa_cost_noops_when_no_cost():
+    """Webset CRUD endpoints don't surface cost_dollars today — the helper
+    must silently skip instead of emitting a 0-cost telemetry record."""
+    block = MagicMock()
+    merge_exa_cost(block, SimpleNamespace(other_field="nothing"))
+    block.merge_stats.assert_not_called()
+
+
+def test_merge_exa_cost_noops_when_response_is_none():
+    block = MagicMock()
+    merge_exa_cost(block, None)
+    block.merge_stats.assert_not_called()
@@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional

 from pydantic import BaseModel

-from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -26,6 +25,7 @@ from backend.sdk import (
 )

 from ._config import exa
+from .helpers import merge_exa_cost


 class ResearchModel(str, Enum):
@@ -233,11 +233,7 @@ class ExaCreateResearchBlock(Block):

             if research.cost_dollars:
                 yield "cost_total", research.cost_dollars.total
-                self.merge_stats(
-                    NodeExecutionStats(
-                        provider_cost=research.cost_dollars.total
-                    )
-                )
+                merge_exa_cost(self, research)
             return

         await asyncio.sleep(check_interval)
@@ -352,9 +348,7 @@ class ExaGetResearchBlock(Block):
             yield "cost_searches", research.cost_dollars.num_searches
             yield "cost_pages", research.cost_dollars.num_pages
             yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
-            self.merge_stats(
-                NodeExecutionStats(provider_cost=research.cost_dollars.total)
-            )
+            merge_exa_cost(self, research)

         yield "error_message", research.error

@@ -441,9 +435,7 @@ class ExaWaitForResearchBlock(Block):

             if research.cost_dollars:
                 yield "cost_total", research.cost_dollars.total
-                self.merge_stats(
-                    NodeExecutionStats(provider_cost=research.cost_dollars.total)
-                )
+                merge_exa_cost(self, research)

             return
Some files were not shown because too many files have changed in this diff.