mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
36 Commits
ci/cla-lab
...
test-scree
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a2373bf61 | ||
|
|
63c4229774 | ||
|
|
c0a27ab878 | ||
|
|
08b568021b | ||
|
|
316b132a13 | ||
|
|
db25bbf47d | ||
|
|
2517dae85a | ||
|
|
080d42b9da | ||
|
|
3d7b381620 | ||
|
|
02be5440fc | ||
|
|
e17e9f13c4 | ||
|
|
f238c153a5 | ||
|
|
01f1289aac | ||
|
|
343222ace1 | ||
|
|
a8226af725 | ||
|
|
f06b5293de | ||
|
|
70b591d74f | ||
|
|
b1c043c2d8 | ||
|
|
fcaebd1bb7 | ||
|
|
3a01874911 | ||
|
|
6d770d9917 | ||
|
|
334ec18c31 | ||
|
|
ea5cfdfa2e | ||
|
|
d13a85bef7 | ||
|
|
60b85640e7 | ||
|
|
87e4d42750 | ||
|
|
0339d95d12 | ||
|
|
f410929560 | ||
|
|
2bbec09e1a | ||
|
|
31b88a6e56 | ||
|
|
d357956d98 | ||
|
|
697ffa81f0 | ||
|
|
2b4727e8b2 | ||
|
|
0d4b31e8a1 | ||
|
|
0cd0a76305 | ||
|
|
d01a51be0e |
@@ -25,6 +25,8 @@ Understand the **Why / What / How** before addressing comments — you need cont
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks.
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
@@ -109,12 +111,16 @@ Only after this loop completes (all pages fetched, count confirmed) should you b
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`).
|
||||
|
||||
### 2. Top-level reviews — REST (MUST paginate)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
|
||||
> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.**
|
||||
|
||||
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
|
||||
|
||||
Two things to extract:
|
||||
@@ -133,6 +139,8 @@ Two things to extract:
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
|
||||
> **Already REST — unaffected by GraphQL rate limits.**
|
||||
|
||||
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
|
||||
|
||||
## For each unaddressed comment
|
||||
@@ -327,18 +335,65 @@ git push
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
|
||||
## GitHub abuse rate limits
|
||||
## GitHub rate limits
|
||||
|
||||
Two distinct rate limits exist — they have different causes and recovery times:
|
||||
Three distinct rate limits exist — they have different causes, error shapes, and recovery times:
|
||||
|
||||
| Error | HTTP code | Cause | Recovery |
|
||||
|---|---|---|---|
|
||||
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **2–3 minutes**. 60s is often not enough. |
|
||||
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |
|
||||
| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until `X-RateLimit-Reset` header timestamp |
|
||||
| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit — distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because point costs per query. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works — use fallbacks below. |
|
||||
|
||||
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
|
||||
|
||||
**Recovery from secondary rate limit (403):**
|
||||
### Detection
|
||||
|
||||
The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID <id>` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited OR fully down):
|
||||
|
||||
```bash
|
||||
gh api rate_limit --jq '.resources.graphql' # { "limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...}
|
||||
# Human-readable reset:
|
||||
gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {}
|
||||
```
|
||||
|
||||
Retry when `remaining > 0`. If you need to proceed sooner, sleep 2–5 min and probe again — the limit is per user, not per machine, so other concurrent agents under the same token also consume it.
|
||||
|
||||
### What keeps working
|
||||
|
||||
When GraphQL is unavailable (rate-limited or outage):
|
||||
|
||||
- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe.
|
||||
- **Degraded:** inline thread list — fall back to flat `/pulls/{N}/comments` REST, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply.
|
||||
- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset.
|
||||
|
||||
### Fall back to REST
|
||||
|
||||
**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body' # == --json body
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref' # == --json baseRefName
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable' # == --json mergeable
|
||||
```
|
||||
|
||||
Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again).
|
||||
|
||||
**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \
|
||||
| jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]'
|
||||
```
|
||||
|
||||
Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets.
|
||||
|
||||
**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed, use the same command as the main flow.
|
||||
|
||||
**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance).
|
||||
|
||||
### Recovery from secondary rate limit (403 abuse)
|
||||
|
||||
1. Stop all API writes immediately
|
||||
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
|
||||
3. Resume with `sleep 3` between each call
|
||||
@@ -397,6 +452,8 @@ gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREA
|
||||
|
||||
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
|
||||
|
||||
> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow.
|
||||
|
||||
### Verify actual count before outputting ORCHESTRATOR:DONE
|
||||
|
||||
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:
|
||||
|
||||
@@ -5,7 +5,7 @@ user-invocable: true
|
||||
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "2.0.0"
|
||||
version: "2.1.0"
|
||||
---
|
||||
|
||||
# Manual E2E Test
|
||||
@@ -180,6 +180,94 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
|
||||
|
||||
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
|
||||
|
||||
## Step 3.0: Claim the testing lock (coordinate parallel agents)
|
||||
|
||||
Multiple worktrees share the same host — Docker infra (postgres, redis, clamav), app ports (3000/8006/…), and the test user. Two agents running `/pr-test` concurrently will corrupt each other's state (connection-pool exhaustion, port binds failing silently, cross-test assertions). Use the root-worktree lock file to take turns.
|
||||
|
||||
### Lock file contract
|
||||
|
||||
Path (**always** the root worktree so all siblings see it): `/Users/majdyz/Code/AutoGPT/.ign.testing.lock`
|
||||
|
||||
Body (one `key=value` per line):
|
||||
```
|
||||
holder=<pr-XXXXX-purpose>
|
||||
pid=<pid-or-"self">
|
||||
started=<iso8601>
|
||||
heartbeat=<iso8601, updated every ~2 min>
|
||||
worktree=<full path>
|
||||
branch=<branch name>
|
||||
intent=<one-line description + rough duration>
|
||||
```
|
||||
|
||||
### Claim
|
||||
|
||||
```bash
|
||||
LOCK=/Users/majdyz/Code/AutoGPT/.ign.testing.lock
|
||||
NOW=$(date -u +%Y-%m-%dT%H:%MZ)
|
||||
STALE_AFTER_MIN=5
|
||||
|
||||
if [ -f "$LOCK" ]; then
|
||||
HB=$(grep '^heartbeat=' "$LOCK" | cut -d= -f2)
|
||||
HB_EPOCH=$(date -j -f '%Y-%m-%dT%H:%MZ' "$HB" +%s 2>/dev/null || date -d "$HB" +%s 2>/dev/null || echo 0)
|
||||
AGE_MIN=$(( ( $(date -u +%s) - HB_EPOCH ) / 60 ))
|
||||
if [ "$AGE_MIN" -gt "$STALE_AFTER_MIN" ]; then
|
||||
echo "WARN: stale lock (${AGE_MIN}m old) — reclaiming"
|
||||
cat "$LOCK" | sed 's/^/ stale: /'
|
||||
else
|
||||
echo "Another agent holds the lock:"; cat "$LOCK"
|
||||
echo "Wait until released or resume after $((STALE_AFTER_MIN - AGE_MIN))m."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
cat > "$LOCK" <<EOF
|
||||
holder=pr-${PR_NUMBER}-e2e
|
||||
pid=self
|
||||
started=$NOW
|
||||
heartbeat=$NOW
|
||||
worktree=$WORKTREE_PATH
|
||||
branch=$(cd $WORKTREE_PATH && git branch --show-current)
|
||||
intent=E2E test PR #${PR_NUMBER}, native mode, ~60min
|
||||
EOF
|
||||
echo "Lock claimed"
|
||||
```
|
||||
|
||||
### Heartbeat (MUST run in background during the whole test)
|
||||
|
||||
Without a heartbeat a crashed agent keeps the lock forever. Run this as a background process right after claim:
|
||||
|
||||
```bash
|
||||
(while true; do
|
||||
sleep 120
|
||||
[ -f "$LOCK" ] || exit 0 # lock released → exit heartbeat
|
||||
perl -i -pe "s/^heartbeat=.*/heartbeat=$(date -u +%Y-%m-%dT%H:%MZ)/" "$LOCK"
|
||||
done) &
|
||||
HEARTBEAT_PID=$!
|
||||
echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
|
||||
```
|
||||
|
||||
### Release (always — even on failure)
|
||||
|
||||
```bash
|
||||
kill "$HEARTBEAT_PID" 2>/dev/null
|
||||
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
|
||||
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
|
||||
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
|
||||
```
|
||||
|
||||
Use a `trap` so release runs even on `exit 1`:
|
||||
```bash
|
||||
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
|
||||
```
|
||||
|
||||
### Shared status log
|
||||
|
||||
`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
|
||||
```bash
|
||||
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
|
||||
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
|
||||
```
|
||||
|
||||
## Step 3: Environment setup
|
||||
|
||||
### 3a. Copy .env files from the root worktree
|
||||
@@ -248,7 +336,87 @@ docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocke
|
||||
done
|
||||
```
|
||||
|
||||
### 3e. Build and start
|
||||
**Native mode also:** when running the app natively (see 3e-native), kill any stray host processes and free the app ports before starting — otherwise `poetry run app` and `pnpm dev` will fail to bind.
|
||||
|
||||
```bash
|
||||
# Kill stray native app processes from prior runs
|
||||
pkill -9 -f "python.*backend" 2>/dev/null || true
|
||||
pkill -9 -f "poetry run app" 2>/dev/null || true
|
||||
pkill -9 -f "next-server|next dev" 2>/dev/null || true
|
||||
|
||||
# Free app ports (errors per port are ignored — port may simply be unused)
|
||||
for port in 3000 8006 8001 8002 8005 8008; do
|
||||
lsof -ti :$port -sTCP:LISTEN | xargs -r kill -9 2>/dev/null || true
|
||||
done
|
||||
```
|
||||
|
||||
### 3e-native. Run the app natively (PREFERRED for iterative dev)
|
||||
|
||||
Native mode runs infra (postgres, supabase, redis, rabbitmq, clamav) in docker but runs the backend and frontend directly on the host. This avoids the 3-8 minute `docker compose build` cycle on every backend change — code edits are picked up on process restart (seconds) instead of a full image rebuild.
|
||||
|
||||
**When to prefer native mode (default for this skill):**
|
||||
- Iterative dev/debug loops where you're editing backend or frontend code between test runs
|
||||
- Any PR that touches Python/TS source but not Dockerfiles, compose config, or infra images
|
||||
- Fast repro of a failing scenario — restart `poetry run app` in a couple of seconds
|
||||
|
||||
**When to prefer docker mode (3e fallback):**
|
||||
- Testing changes to `Dockerfile`, `docker-compose.yml`, or base images
|
||||
- Production-parity smoke tests (exact container env, networking, volumes)
|
||||
- CI-equivalent runs where you need the exact image that'll ship
|
||||
|
||||
**Note on 3b (copilot auth):** no npm install anywhere. `poetry install` pulls in `claude_agent_sdk`, which ships its own Claude CLI binary — available on `PATH` whenever you run commands via `poetry run` (native) OR whenever the copilot_executor container is built from its Poetry lockfile (docker). The OAuth token extraction still applies (same `refresh_claude_token.sh` call).
|
||||
|
||||
**Preamble:** before starting native, run the kill-stray + free-ports block from 3c's "Native mode also" subsection.
|
||||
|
||||
**1. Start infra only (one-time per session):**
|
||||
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose --profile local up deps --detach --remove-orphans --build
|
||||
```
|
||||
|
||||
This brings up postgres/supabase/redis/rabbitmq/clamav and skips all app services.
|
||||
|
||||
**2. Start the backend natively:**
|
||||
|
||||
```bash
|
||||
cd $BACKEND_DIR && (poetry run app 2>&1 | tee .ign.application.logs) &
|
||||
```
|
||||
|
||||
`poetry run app` spawns **all** app subprocesses — `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, `database_manager` — inside ONE parent process. No separate containers, no separate terminals. The `.ign.application.logs` prefix is already gitignored.
|
||||
|
||||
**3. Wait for the backend on :8006 BEFORE starting the frontend.** This ordering matters — the frontend's `pnpm dev` startup invokes `generate-api-queries`, which fetches `/openapi.json` from the backend. If the backend isn't listening yet, `pnpm dev` fails immediately.
|
||||
|
||||
```bash
|
||||
for i in $(seq 1 60); do
|
||||
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs 2>/dev/null)" = "200" ]; then
|
||||
echo "Backend ready"
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
```
|
||||
|
||||
**4. Start the frontend natively:**
|
||||
|
||||
```bash
|
||||
cd $FRONTEND_DIR && (pnpm dev 2>&1 | tee .ign.frontend.logs) &
|
||||
```
|
||||
|
||||
**5. Wait for the frontend on :3000:**
|
||||
|
||||
```bash
|
||||
for i in $(seq 1 60); do
|
||||
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:3000 2>/dev/null)" = "200" ]; then
|
||||
echo "Frontend ready"
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
```
|
||||
|
||||
Once both are up, skip 3e/3f and go straight to **3g/3h** (feature flags / test user creation).
|
||||
|
||||
### 3e. Build and start (docker — fallback)
|
||||
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
|
||||
@@ -442,6 +610,22 @@ agent-browser --session-name pr-test snapshot | grep "text:"
|
||||
|
||||
### Checking logs
|
||||
|
||||
**Native mode:** when running via `poetry run app` + `pnpm dev`, all app logs stream to the `.ign.*.logs` files written by the `tee` pipes in 3e-native. `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, and `database_manager` are all subprocesses of the single `poetry run app` parent, so their output is interleaved in `.ign.application.logs`.
|
||||
|
||||
```bash
|
||||
# Backend (all app subprocesses interleaved)
|
||||
tail -f $BACKEND_DIR/.ign.application.logs
|
||||
|
||||
# Frontend (Next.js dev server)
|
||||
tail -f $FRONTEND_DIR/.ign.frontend.logs
|
||||
|
||||
# Filter for errors across either log
|
||||
grep -iE "error|exception|traceback" $BACKEND_DIR/.ign.application.logs | tail -20
|
||||
grep -iE "error|exception|traceback" $FRONTEND_DIR/.ign.frontend.logs | tail -20
|
||||
```
|
||||
|
||||
**Docker mode:**
|
||||
|
||||
```bash
|
||||
# Backend REST server
|
||||
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
|
||||
@@ -876,9 +1060,15 @@ test scenario → find issue (bug OR UX problem) → screenshot broken state
|
||||
### Problem: Frontend shows cookie banner blocking interaction
|
||||
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
|
||||
|
||||
### Problem: Container loses npm packages after rebuild
|
||||
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
|
||||
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
|
||||
### Problem: Claude CLI not found in copilot_executor container
|
||||
**Symptom:** Copilot logs say `claude: command not found` or similar when starting an SDK turn.
|
||||
**Cause:** Image was built without `poetry install` (stale base layer, or Dockerfile bypass). The SDK CLI ships inside the `claude_agent_sdk` Poetry dep — it is NOT an npm package.
|
||||
**Fix:** Rebuild the image cleanly: `docker compose build --no-cache copilot_executor && docker compose up -d copilot_executor`. Do NOT `docker exec ... npm install -g @anthropic-ai/claude-code` — that is outdated guidance and will pollute the container with a second CLI that the SDK won't use.
|
||||
|
||||
### Problem: agent-browser screenshot hangs / times out
|
||||
**Symptom:** `agent-browser screenshot` exits with code 124 even on `about:blank`.
|
||||
**Cause:** Stuck CDP connection or Chromium process tree. Seen on macOS when a prior `/pr-test` left a zombie Chrome for Testing.
|
||||
**Fix:** `pkill -9 -f "agent-browser|chromium|Chrome for Testing" && sleep 2`, then reopen the browser with a fresh `--session-name`. If still failing, verify via `agent-browser eval` + `agent-browser snapshot` (DOM state) instead of relying on PNGs — the feature under test is the same.
|
||||
|
||||
### Problem: Services not starting after `docker compose up`
|
||||
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
|
||||
|
||||
430
.github/workflows/cla-label-sync.yml
vendored
430
.github/workflows/cla-label-sync.yml
vendored
@@ -1,430 +0,0 @@
|
||||
name: CLA Label Sync
|
||||
|
||||
on:
|
||||
# Real-time: when CLA status changes (CLA-assistant uses Status API)
|
||||
status:
|
||||
|
||||
# When PRs are opened or updated
|
||||
pull_request_target:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
# Scheduled sweep - check stale PRs daily
|
||||
schedule:
|
||||
- cron: '0 9 * * *' # 9 AM UTC daily
|
||||
|
||||
# Manual trigger for testing
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pr_number:
|
||||
description: 'Specific PR number to check (optional)'
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
contents: read
|
||||
statuses: read
|
||||
checks: read
|
||||
|
||||
env:
|
||||
CLA_CHECK_NAME: 'license/cla'
|
||||
LABEL_PENDING: 'cla: pending'
|
||||
LABEL_SIGNED: 'cla: signed'
|
||||
# Timing configuration (all independently configurable)
|
||||
REMINDER_DAYS: 7 # Days before first reminder
|
||||
CLOSE_WARNING_DAYS: 23 # Days before "closing soon" warning
|
||||
CLOSE_DAYS: 30 # Days before auto-close
|
||||
|
||||
jobs:
|
||||
sync-labels:
|
||||
runs-on: ubuntu-latest
|
||||
# Only run on status events if it's the CLA check
|
||||
if: github.event_name != 'status' || github.event.context == 'license/cla'
|
||||
|
||||
steps:
|
||||
- name: Ensure CLA labels exist
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const labels = [
|
||||
{ name: 'cla: pending', color: 'fbca04', description: 'CLA not yet signed by all contributors' },
|
||||
{ name: 'cla: signed', color: '0e8a16', description: 'CLA signed by all contributors' }
|
||||
];
|
||||
|
||||
for (const label of labels) {
|
||||
try {
|
||||
await github.rest.issues.getLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
name: label.name
|
||||
});
|
||||
} catch (e) {
|
||||
if (e.status === 404) {
|
||||
await github.rest.issues.createLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
name: label.name,
|
||||
color: label.color,
|
||||
description: label.description
|
||||
});
|
||||
console.log(`Created label: ${label.name}`);
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- name: Sync CLA labels and handle stale PRs
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const CLA_CHECK_NAME = process.env.CLA_CHECK_NAME;
|
||||
const LABEL_PENDING = process.env.LABEL_PENDING;
|
||||
const LABEL_SIGNED = process.env.LABEL_SIGNED;
|
||||
const REMINDER_DAYS = parseInt(process.env.REMINDER_DAYS);
|
||||
const CLOSE_WARNING_DAYS = parseInt(process.env.CLOSE_WARNING_DAYS);
|
||||
const CLOSE_DAYS = parseInt(process.env.CLOSE_DAYS);
|
||||
|
||||
// Validate timing configuration
|
||||
if ([REMINDER_DAYS, CLOSE_WARNING_DAYS, CLOSE_DAYS].some(Number.isNaN)) {
|
||||
core.setFailed('Invalid timing configuration — REMINDER_DAYS, CLOSE_WARNING_DAYS, and CLOSE_DAYS must be numeric.');
|
||||
return;
|
||||
}
|
||||
if (!(REMINDER_DAYS < CLOSE_WARNING_DAYS && CLOSE_WARNING_DAYS < CLOSE_DAYS)) {
|
||||
core.warning(`Timing order looks odd: REMINDER(${REMINDER_DAYS}) < WARNING(${CLOSE_WARNING_DAYS}) < CLOSE(${CLOSE_DAYS}) expected.`);
|
||||
}
|
||||
|
||||
const CLA_SIGN_URL = `https://cla-assistant.io/${context.repo.owner}/${context.repo.repo}`;
|
||||
|
||||
// Helper: Get CLA status for a commit
|
||||
async function getClaStatus(headSha) {
|
||||
// CLA-assistant uses the commit status API (not checks API)
|
||||
const { data: statuses } = await github.rest.repos.getCombinedStatusForRef({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
ref: headSha
|
||||
});
|
||||
|
||||
const claStatus = statuses.statuses.find(
|
||||
s => s.context === CLA_CHECK_NAME
|
||||
);
|
||||
|
||||
if (claStatus) {
|
||||
return {
|
||||
found: true,
|
||||
passed: claStatus.state === 'success',
|
||||
state: claStatus.state,
|
||||
description: claStatus.description
|
||||
};
|
||||
}
|
||||
|
||||
// Fallback: check the Checks API too
|
||||
const { data: checkRuns } = await github.rest.checks.listForRef({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
ref: headSha
|
||||
});
|
||||
|
||||
const claCheck = checkRuns.check_runs.find(
|
||||
check => check.name === CLA_CHECK_NAME
|
||||
);
|
||||
|
||||
if (claCheck) {
|
||||
return {
|
||||
found: true,
|
||||
passed: claCheck.conclusion === 'success',
|
||||
state: claCheck.conclusion,
|
||||
description: claCheck.output?.summary || ''
|
||||
};
|
||||
}
|
||||
|
||||
return { found: false, passed: false, state: 'unknown' };
|
||||
}
|
||||
|
||||
// Helper: Check if bot already commented with a specific marker (paginated)
|
||||
async function hasCommentWithMarker(prNumber, marker) {
|
||||
// Use paginate to fetch ALL comments, not just first 100
|
||||
const comments = await github.paginate(
|
||||
github.rest.issues.listComments,
|
||||
{
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
per_page: 100
|
||||
}
|
||||
);
|
||||
|
||||
return comments.some(c =>
|
||||
c.user?.type === 'Bot' &&
|
||||
c.body?.includes(marker)
|
||||
);
|
||||
}
|
||||
|
||||
// Helper: Days since a date
|
||||
function daysSince(dateString) {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
return Math.floor((now - date) / (1000 * 60 * 60 * 24));
|
||||
}
|
||||
|
||||
// Determine which PRs to check
|
||||
let prsToCheck = [];
|
||||
|
||||
if (context.eventName === 'status') {
|
||||
// Status event from CLA-assistant - find PRs with this commit
|
||||
const sha = context.payload.sha;
|
||||
console.log(`Status event for SHA: ${sha}, context: ${context.payload.context}`);
|
||||
|
||||
// Search for open PRs with this head SHA (paginated)
|
||||
const allPRs = await github.paginate(
|
||||
github.rest.pulls.list,
|
||||
{
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
state: 'open',
|
||||
per_page: 100
|
||||
}
|
||||
);
|
||||
prsToCheck = allPRs.filter(pr => pr.head.sha === sha).map(pr => pr.number);
|
||||
|
||||
if (prsToCheck.length === 0) {
|
||||
console.log('No open PRs found with this SHA');
|
||||
return;
|
||||
}
|
||||
|
||||
} else if (context.eventName === 'pull_request_target') {
|
||||
prsToCheck = [context.payload.pull_request.number];
|
||||
|
||||
} else if (context.eventName === 'workflow_dispatch' && context.payload.inputs?.pr_number) {
|
||||
prsToCheck = [parseInt(context.payload.inputs.pr_number)];
|
||||
|
||||
} else {
|
||||
// Scheduled run: check all open PRs (paginated to handle >100 PRs)
|
||||
const openPRs = await github.paginate(
|
||||
github.rest.pulls.list,
|
||||
{
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
state: 'open',
|
||||
per_page: 100
|
||||
}
|
||||
);
|
||||
prsToCheck = openPRs.map(pr => pr.number);
|
||||
}
|
||||
|
||||
console.log(`Checking ${prsToCheck.length} PR(s): ${prsToCheck.join(', ')}`);
|
||||
|
||||
for (const prNumber of prsToCheck) {
|
||||
try {
|
||||
// Get PR details
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// Skip if PR is from a bot
|
||||
if (pr.user.type === 'Bot') {
|
||||
console.log(`PR #${prNumber}: Skipping bot PR`);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if PR is not open (closed/merged)
|
||||
if (pr.state !== 'open') {
|
||||
console.log(`PR #${prNumber}: Skipping non-open PR (state=${pr.state})`);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if PR already has cla: signed label (optimization for scheduled sweeps)
|
||||
const currentLabels = pr.labels.map(l => l.name);
|
||||
const knownPlatformPR = currentLabels.includes(LABEL_SIGNED) || currentLabels.includes(LABEL_PENDING);
|
||||
|
||||
// Skip listFiles if we've already labelled this PR (a previous run verified it touches platform code)
|
||||
if (!knownPlatformPR) {
|
||||
const PLATFORM_PATH = 'autogpt_platform/';
|
||||
const prFiles = await github.paginate(
|
||||
github.rest.pulls.listFiles,
|
||||
{
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
per_page: 100
|
||||
}
|
||||
);
|
||||
const touchesPlatform = prFiles.some(f => f.filename.startsWith(PLATFORM_PATH));
|
||||
if (!touchesPlatform) {
|
||||
console.log(`PR #${prNumber}: Skipping - doesn't touch ${PLATFORM_PATH}`);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const claStatus = await getClaStatus(pr.head.sha);
|
||||
const hasPending = currentLabels.includes(LABEL_PENDING);
|
||||
const hasSigned = currentLabels.includes(LABEL_SIGNED);
|
||||
const prAgeDays = daysSince(pr.created_at);
|
||||
|
||||
console.log(`PR #${prNumber}: CLA ${claStatus.passed ? 'passed' : 'pending'} (${claStatus.state}), age: ${prAgeDays} days`);
|
||||
|
||||
if (claStatus.passed) {
|
||||
// ✅ CLA signed - add signed label, remove pending
|
||||
if (!hasSigned) {
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
labels: [LABEL_SIGNED]
|
||||
});
|
||||
console.log(`Added '${LABEL_SIGNED}' to PR #${prNumber}`);
|
||||
}
|
||||
if (hasPending) {
|
||||
await github.rest.issues.removeLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
name: LABEL_PENDING
|
||||
});
|
||||
console.log(`Removed '${LABEL_PENDING}' from PR #${prNumber}`);
|
||||
}
|
||||
|
||||
} else {
|
||||
// ⏳ CLA pending
|
||||
|
||||
// Add pending label if not present
|
||||
if (!hasPending) {
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
labels: [LABEL_PENDING]
|
||||
});
|
||||
console.log(`Added '${LABEL_PENDING}' to PR #${prNumber}`);
|
||||
}
|
||||
if (hasSigned) {
|
||||
await github.rest.issues.removeLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
name: LABEL_SIGNED
|
||||
});
|
||||
console.log(`Removed '${LABEL_SIGNED}' from PR #${prNumber}`);
|
||||
}
|
||||
|
||||
// Check if we need to send reminder or close
|
||||
const REMINDER_MARKER = '<!-- cla-reminder -->';
|
||||
const CLOSE_WARNING_MARKER = '<!-- cla-close-warning -->';
|
||||
|
||||
// 📢 Reminder after REMINDER_DAYS (but before warning window)
|
||||
if (prAgeDays >= REMINDER_DAYS && prAgeDays < CLOSE_WARNING_DAYS) {
|
||||
const hasReminder = await hasCommentWithMarker(prNumber, REMINDER_MARKER);
|
||||
|
||||
if (!hasReminder) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `${REMINDER_MARKER}
|
||||
|
||||
👋 **Friendly reminder:** This PR is waiting on a signed CLA.
|
||||
|
||||
All contributors need to sign our Contributor License Agreement before we can merge this PR.
|
||||
|
||||
**➡️ [Sign the CLA here](${CLA_SIGN_URL}?pullRequest=${prNumber})**
|
||||
|
||||
<details>
|
||||
<summary>Why do we need a CLA?</summary>
|
||||
|
||||
The CLA protects both you and the project by clarifying the terms under which your contribution is made. It's a one-time process — once signed, it covers all your future contributions.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Common issues</summary>
|
||||
|
||||
- **Email mismatch:** Make sure your Git commit email matches your GitHub account email
|
||||
- **Merge commits:** If you merged \`dev\` into your branch, try rebasing instead: \`git rebase origin/dev && git push --force-with-lease\`
|
||||
- **Multiple authors:** All commit authors need to sign, not just the PR author
|
||||
|
||||
</details>
|
||||
|
||||
If you have questions, just ask! 🙂`
|
||||
});
|
||||
console.log(`Posted reminder on PR #${prNumber}`);
|
||||
}
|
||||
}
|
||||
|
||||
// ⚠️ Close warning at CLOSE_WARNING_DAYS
|
||||
if (prAgeDays >= CLOSE_WARNING_DAYS && prAgeDays < CLOSE_DAYS) {
|
||||
const hasCloseWarning = await hasCommentWithMarker(prNumber, CLOSE_WARNING_MARKER);
|
||||
|
||||
if (!hasCloseWarning) {
|
||||
const daysRemaining = CLOSE_DAYS - prAgeDays;
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `${CLOSE_WARNING_MARKER}
|
||||
|
||||
⚠️ **This PR will be automatically closed in ${daysRemaining} day${daysRemaining === 1 ? '' : 's'}** if the CLA is not signed.
|
||||
|
||||
We haven't received a signed CLA from all contributors yet. Please sign it to keep this PR open:
|
||||
|
||||
**➡️ [Sign the CLA here](${CLA_SIGN_URL}?pullRequest=${prNumber})**
|
||||
|
||||
If you're unable to sign or have questions, please let us know — we're happy to help!`
|
||||
});
|
||||
console.log(`Posted close warning on PR #${prNumber}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 🚪 Auto-close after CLOSE_DAYS
|
||||
if (prAgeDays >= CLOSE_DAYS) {
|
||||
const CLOSE_MARKER = '<!-- cla-auto-closed -->';
|
||||
const OVERRIDE_LABEL = 'cla: override';
|
||||
|
||||
// Check for override label (maintainer wants to keep PR open)
|
||||
if (currentLabels.includes(OVERRIDE_LABEL)) {
|
||||
console.log(`PR #${prNumber}: Skipping close due to '${OVERRIDE_LABEL}' label`);
|
||||
} else {
|
||||
// Check if we already posted a close comment
|
||||
const hasCloseComment = await hasCommentWithMarker(prNumber, CLOSE_MARKER);
|
||||
|
||||
if (!hasCloseComment) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `${CLOSE_MARKER}
|
||||
|
||||
👋 Closing this PR due to unsigned CLA after ${CLOSE_DAYS} days.
|
||||
|
||||
Thank you for your contribution! If you'd still like to contribute:
|
||||
|
||||
1. [Sign the CLA](${CLA_SIGN_URL})
|
||||
2. Re-open this PR or create a new one
|
||||
|
||||
> **Maintainers:** To reopen and exempt from future auto-close, add the \`cla: override\` label before reopening. Without it, the PR will not be re-closed automatically (a reopened PR is treated as a maintainer decision).
|
||||
|
||||
We appreciate your interest in AutoGPT and hope to see you back! 🚀`
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber,
|
||||
state: 'closed'
|
||||
});
|
||||
|
||||
console.log(`Closed PR #${prNumber} due to unsigned CLA`);
|
||||
} else {
|
||||
console.log(`PR #${prNumber}: Already auto-closed previously, skipping (maintainer may have reopened)`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error(`Error processing PR #${prNumber}: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log('CLA label sync complete!');
|
||||
@@ -32,10 +32,10 @@ router = APIRouter(
|
||||
class UserRateLimitResponse(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
daily_token_limit: int
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
daily_cost_limit_microdollars: int
|
||||
weekly_cost_limit_microdollars: int
|
||||
daily_cost_used_microdollars: int
|
||||
weekly_cost_used_microdollars: int
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
@@ -101,17 +101,19 @@ async def get_user_rate_limit(
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
resolved_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
daily_cost_limit_microdollars=daily_limit,
|
||||
weekly_cost_limit_microdollars=weekly_limit,
|
||||
daily_cost_used_microdollars=usage.daily.used,
|
||||
weekly_cost_used_microdollars=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
@@ -141,7 +143,9 @@ async def reset_user_rate_limit(
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
@@ -154,10 +158,10 @@ async def reset_user_rate_limit(
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
daily_cost_limit_microdollars=daily_limit,
|
||||
weekly_cost_limit_microdollars=weekly_limit,
|
||||
daily_cost_used_microdollars=usage.daily.used,
|
||||
weekly_cost_used_microdollars=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@@ -85,10 +85,10 @@ def test_get_rate_limit(
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["daily_cost_limit_microdollars"] == 2_500_000
|
||||
assert data["weekly_cost_limit_microdollars"] == 12_500_000
|
||||
assert data["daily_cost_used_microdollars"] == 500_000
|
||||
assert data["weekly_cost_used_microdollars"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
@@ -117,7 +117,7 @@ def test_get_rate_limit_by_email(
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
assert data["daily_cost_limit_microdollars"] == 2_500_000
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email_not_found(
|
||||
@@ -160,9 +160,9 @@ def test_reset_user_usage_daily_only(
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["daily_cost_used_microdollars"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["weekly_cost_used_microdollars"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
@@ -192,8 +192,8 @@ def test_reset_user_usage_daily_and_weekly(
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
assert data["daily_cost_used_microdollars"] == 0
|
||||
assert data["weekly_cost_used_microdollars"] == 0
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
@@ -18,7 +16,6 @@ from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.message_dedup import acquire_dedup_lock
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -30,8 +27,14 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.pending_message_helpers import (
|
||||
QueuePendingMessageResponse,
|
||||
is_turn_in_flight,
|
||||
queue_pending_for_http,
|
||||
)
|
||||
from backend.copilot.pending_messages import peek_pending_messages
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
CoPilotUsagePublic,
|
||||
RateLimitExceeded,
|
||||
acquire_reset_lock,
|
||||
check_rate_limit,
|
||||
@@ -76,7 +79,7 @@ from backend.copilot.tracking import track_user_message
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.data.workspace import build_files_block, resolve_workspace_files
|
||||
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -86,10 +89,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -152,6 +151,19 @@ class StreamChatRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class PeekPendingMessagesResponse(BaseModel):
|
||||
"""Response for the pending-message peek (GET) endpoint.
|
||||
|
||||
Returns a read-only view of the pending buffer — messages are NOT
|
||||
consumed. The frontend uses this to restore the queued-message
|
||||
indicator after a page refresh and to decide when to clear it once
|
||||
a turn has ended.
|
||||
"""
|
||||
|
||||
messages: list[str]
|
||||
count: int
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating a new chat session.
|
||||
|
||||
@@ -463,22 +475,13 @@ async def get_session(
|
||||
|
||||
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
|
||||
When no pagination params are provided, returns the most recent messages.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of messages to return (1-200, default 50).
|
||||
before_sequence: Return messages with sequence < this value (cursor).
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including
|
||||
active_stream info and pagination metadata.
|
||||
"""
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id, limit, before_sequence, user_id=user_id
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [
|
||||
_strip_injected_context(message.model_dump()) for message in page.messages
|
||||
]
|
||||
@@ -489,10 +492,6 @@ async def get_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
@@ -537,23 +536,27 @@ async def get_session(
|
||||
)
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
) -> CoPilotUsagePublic:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
Includes the user's rate-limit tier.
|
||||
Returns the percentage of the daily/weekly allowance used — not the
|
||||
raw spend or cap — so clients cannot derive per-turn cost or platform
|
||||
margins. Global defaults sourced from LaunchDarkly (falling back to
|
||||
config). Includes the user's rate-limit tier.
|
||||
"""
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
)
|
||||
return await get_usage_status(
|
||||
status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
return CoPilotUsagePublic.from_status(status)
|
||||
|
||||
|
||||
class RateLimitResetResponse(BaseModel):
|
||||
@@ -562,7 +565,9 @@ class RateLimitResetResponse(BaseModel):
|
||||
success: bool
|
||||
credits_charged: int = Field(description="Credits charged (in cents)")
|
||||
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
|
||||
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
|
||||
usage: CoPilotUsagePublic = Field(
|
||||
description="Updated usage status after reset (percentages only)"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -586,7 +591,7 @@ async def reset_copilot_usage(
|
||||
) -> RateLimitResetResponse:
|
||||
"""Reset the daily CoPilot rate limit by spending credits.
|
||||
|
||||
Allows users who have hit their daily token limit to spend credits
|
||||
Allows users who have hit their daily cost limit to spend credits
|
||||
to reset their daily usage counter and continue working.
|
||||
Returns 400 if the feature is disabled or the user is not over the limit.
|
||||
Returns 402 if the user has insufficient credits.
|
||||
@@ -605,7 +610,9 @@ async def reset_copilot_usage(
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
)
|
||||
|
||||
if daily_limit <= 0:
|
||||
@@ -642,8 +649,8 @@ async def reset_copilot_usage(
|
||||
# used for limit checks, not returned to the client.)
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
tier=tier,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
@@ -678,7 +685,7 @@ async def reset_copilot_usage(
|
||||
|
||||
# Reset daily usage in Redis. If this fails, refund the credits
|
||||
# so the user is not charged for a service they did not receive.
|
||||
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
|
||||
if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit):
|
||||
# Compensate: refund the charged credits.
|
||||
refunded = False
|
||||
try:
|
||||
@@ -714,11 +721,11 @@ async def reset_copilot_usage(
|
||||
finally:
|
||||
await release_reset_lock(user_id)
|
||||
|
||||
# Return updated usage status.
|
||||
# Return updated usage status (public schema — percentages only).
|
||||
updated_usage = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
@@ -727,7 +734,7 @@ async def reset_copilot_usage(
|
||||
success=True,
|
||||
credits_charged=cost,
|
||||
remaining_balance=remaining,
|
||||
usage=updated_usage,
|
||||
usage=CoPilotUsagePublic.from_status(updated_usage),
|
||||
)
|
||||
|
||||
|
||||
@@ -778,36 +785,52 @@ async def cancel_session_task(
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
responses={
|
||||
202: {
|
||||
"model": QueuePendingMessageResponse,
|
||||
"description": (
|
||||
"Session has a turn in flight — message queued into the pending "
|
||||
"buffer and will be picked up between tool-call rounds by the "
|
||||
"executor currently processing the turn."
|
||||
),
|
||||
},
|
||||
404: {"description": "Session not found or access denied"},
|
||||
429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
|
||||
},
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
"""Start a new turn OR queue a follow-up — decided server-side.
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
- **Session idle**: starts a turn. Returns an SSE stream (``text/event-stream``)
|
||||
with Vercel AI SDK chunks (text fragments, tool-call UI, tool results).
|
||||
The generation runs in a background task that survives client disconnects;
|
||||
reconnect via ``GET /sessions/{session_id}/stream`` to resume.
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to a per-turn Redis stream for reconnection support. If the client
|
||||
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
|
||||
- **Session has a turn in flight**: pushes the message into the per-session
|
||||
pending buffer and returns ``202 application/json`` with
|
||||
``QueuePendingMessageResponse``. The executor running the current turn
|
||||
drains the buffer between tool-call rounds (baseline) or at the start of
|
||||
the next turn (SDK). Clients should detect the 202 and surface the
|
||||
message as a queued-chip in the UI.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
session_id: The chat session identifier.
|
||||
request: Request body with message, is_user_message, and optional context.
|
||||
user_id: Authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
# Wall-clock arrival time, propagated to the executor so the turn-start
|
||||
# drain can order pending messages relative to this request (pending
|
||||
# pushed BEFORE this instant were typed earlier; pending pushed AFTER
|
||||
# are race-path follow-ups typed while /stream was still processing).
|
||||
request_arrival_at = time.time()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
|
||||
|
||||
logger.info(
|
||||
@@ -816,6 +839,26 @@ async def stream_chat_post(
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
# Self-defensive queue-fallback: if a turn is already running, don't race
|
||||
# it on the cluster lock — drop the message into the pending buffer and
|
||||
# return 202 so the caller can render a chip. Both UI chips and autopilot
|
||||
# block follow-ups route through this path; keeping the decision on the
|
||||
# server means every caller gets uniform behaviour.
|
||||
if (
|
||||
request.is_user_message
|
||||
and request.message
|
||||
and await is_turn_in_flight(session_id)
|
||||
):
|
||||
response = await queue_pending_for_http(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
context=request.context,
|
||||
file_ids=request.file_ids,
|
||||
)
|
||||
return JSONResponse(status_code=202, content=response.model_dump())
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
|
||||
extra={
|
||||
@@ -826,18 +869,20 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# Pre-turn rate limit check (cost-based, microdollars).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit, _ = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
user_id,
|
||||
config.daily_cost_limit_microdollars,
|
||||
config.weekly_cost_limit_microdollars,
|
||||
)
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_cost_limit=daily_limit,
|
||||
weekly_cost_limit=weekly_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
@@ -846,89 +891,41 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
# Capture the original message text BEFORE any mutation (attachment enrichment)
|
||||
# so the idempotency hash is stable across retries.
|
||||
original_message = request.message
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# ── Idempotency guard ────────────────────────────────────────────────────
|
||||
# Blocks duplicate executor tasks from concurrent/retried POSTs.
|
||||
# See backend/copilot/message_dedup.py for the full lifecycle description.
|
||||
dedup_lock = None
|
||||
if request.is_user_message:
|
||||
dedup_lock = await acquire_dedup_lock(
|
||||
session_id, original_message, sanitized_file_ids
|
||||
)
|
||||
if dedup_lock is None and (original_message or sanitized_file_ids):
|
||||
|
||||
async def _empty_sse() -> AsyncGenerator[str, None]:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_empty_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
if request.file_ids:
|
||||
files = await resolve_workspace_files(user_id, request.file_ids)
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
request.message += build_files_block(files)
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
#
|
||||
# If any of these operations raises, release the dedup lock before propagating
|
||||
# so subsequent retries are not blocked for 30 s.
|
||||
try:
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
# saved yet. append_and_save_message returns None when a duplicate is
|
||||
# detected — in that case skip enqueue to avoid processing the message twice.
|
||||
is_duplicate_message = False
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
is_duplicate_message = (
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
) is None
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
if not is_duplicate_message and request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
# Create a task in the stream registry for reconnection support.
|
||||
# For duplicate messages, skip create_session entirely so the infra-retry
|
||||
# client subscribes to the *existing* turn's Redis stream and receives the
|
||||
# in-progress executor output rather than an empty stream.
|
||||
turn_id = ""
|
||||
if not is_duplicate_message:
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
@@ -946,7 +943,6 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
@@ -957,11 +953,12 @@ async def stream_chat_post(
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
request_arrival_at=request_arrival_at,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
|
||||
)
|
||||
except Exception:
|
||||
if dedup_lock:
|
||||
await dedup_lock.release()
|
||||
raise
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
@@ -985,12 +982,6 @@ async def stream_chat_post(
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
# True for every exit path except GeneratorExit (client disconnect).
|
||||
# On disconnect the backend turn is still running — releasing the lock
|
||||
# there would reopen the infra-retry duplicate window. The 30 s TTL
|
||||
# is the fallback. All other exits (normal finish, early return, error)
|
||||
# should release so the user can re-send the same message.
|
||||
release_dedup_lock_on_exit = True
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
@@ -1002,7 +993,7 @@ async def stream_chat_post(
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
return # finally releases dedup_lock
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
@@ -1044,7 +1035,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
break # finally releases dedup_lock
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
@@ -1060,7 +1051,6 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
release_dedup_lock_on_exit = False
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
@@ -1075,10 +1065,7 @@ async def stream_chat_post(
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
# finally releases dedup_lock
|
||||
finally:
|
||||
if dedup_lock and release_dedup_lock_on_exit:
|
||||
await dedup_lock.release()
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
@@ -1117,6 +1104,31 @@ async def stream_chat_post(
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/messages/pending",
|
||||
response_model=PeekPendingMessagesResponse,
|
||||
responses={
|
||||
404: {"description": "Session not found or access denied"},
|
||||
},
|
||||
)
|
||||
async def get_pending_messages(
|
||||
session_id: str,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""Peek at the pending-message buffer without consuming it.
|
||||
|
||||
Returns the current contents of the session's pending message buffer
|
||||
so the frontend can restore the queued-message indicator after a page
|
||||
refresh and clear it correctly once a turn drains the buffer.
|
||||
"""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
pending = await peek_pending_messages(session_id)
|
||||
return PeekPendingMessagesResponse(
|
||||
messages=[m.content for m in pending],
|
||||
count=len(pending),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
|
||||
@@ -133,21 +133,12 @@ def test_stream_chat_rejects_too_many_file_ids():
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
redis_set_returns: object = True,
|
||||
):
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ.
|
||||
|
||||
Args:
|
||||
redis_set_returns: Value returned by the mocked Redis ``set`` call.
|
||||
``True`` (default) simulates a fresh key (new message);
|
||||
``None`` simulates a collision (duplicate blocked).
|
||||
validation and enrichment logic without needing RabbitMQ.
|
||||
|
||||
Returns:
|
||||
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
|
||||
A namespace with ``save`` and ``enqueue`` mock objects so
|
||||
callers can make additional assertions about side-effects.
|
||||
"""
|
||||
import types
|
||||
@@ -158,7 +149,7 @@ def _mock_stream_internals(
|
||||
)
|
||||
mock_save = mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
@@ -174,15 +165,9 @@ def _mock_stream_internals(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=redis_set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
return types.SimpleNamespace(
|
||||
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
|
||||
)
|
||||
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
|
||||
return ns
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
@@ -190,7 +175,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
"backend.data.workspace.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
@@ -211,6 +196,29 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── Duplicate message dedup ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_skips_enqueue_for_duplicate_message(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
"""When append_and_save_message returns None (duplicate detected),
|
||||
enqueue_copilot_turn and stream_registry.create_session must NOT be called
|
||||
to avoid double-processing and to prevent overwriting the active stream's
|
||||
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
|
||||
mocks = _mock_stream_internals(mocker)
|
||||
# Override save to return None — signalling a duplicate
|
||||
mocks.save.return_value = None
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mocks.enqueue.assert_not_called()
|
||||
mocks.registry.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -219,7 +227,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture):
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
"backend.data.workspace.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
@@ -257,7 +265,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
"backend.data.workspace.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
@@ -288,8 +296,8 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerF
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
# Ensure the rate-limit branch is entered by setting a non-zero limit.
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
|
||||
@@ -310,8 +318,8 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
|
||||
resets_at = datetime.now(UTC) + timedelta(days=3)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
@@ -333,8 +341,8 @@ def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded(
|
||||
@@ -394,23 +402,33 @@ def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
"""GET /usage returns percentages for daily and weekly windows only.
|
||||
|
||||
The raw used/limit microdollar values MUST NOT leak — clients should not
|
||||
be able to derive per-turn cost or platform margins from the public API.
|
||||
"""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
# 500 / 10000 = 5%, 2000 / 50000 = 4%
|
||||
assert data["daily"]["percent_used"] == 5.0
|
||||
assert data["weekly"]["percent_used"] == 4.0
|
||||
# Raw spend/limit must not be exposed.
|
||||
assert "used" not in data["daily"]
|
||||
assert "limit" not in data["daily"]
|
||||
assert "used" not in data["weekly"]
|
||||
assert "limit" not in data["weekly"]
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
daily_cost_limit=10000,
|
||||
weekly_cost_limit=50000,
|
||||
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
@@ -430,8 +448,8 @@ def test_usage_uses_config_limits(
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
daily_cost_limit=99999,
|
||||
weekly_cost_limit=77777,
|
||||
rate_limit_reset_cost=500,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
@@ -609,6 +627,246 @@ class TestStreamChatRequestModeValidation:
|
||||
assert req.mode is None
|
||||
|
||||
|
||||
# ─── POST /stream queue-fallback (when a turn is already in flight) ──
|
||||
|
||||
|
||||
def _mock_stream_queue_internals(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
session_exists: bool = True,
|
||||
turn_in_flight: bool = True,
|
||||
call_count: int = 1,
|
||||
):
|
||||
"""Mock dependencies for the POST /stream queue-fallback path.
|
||||
|
||||
When ``turn_in_flight`` is True the handler takes the 202 queue branch.
|
||||
"""
|
||||
if session_exists:
|
||||
mock_session = mocker.MagicMock()
|
||||
mock_session.id = "sess-1"
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_session,
|
||||
)
|
||||
else:
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
side_effect=fastapi.HTTPException(
|
||||
status_code=404, detail="Session not found."
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.is_turn_in_flight",
|
||||
new_callable=AsyncMock,
|
||||
return_value=turn_in_flight,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(0, 0, None),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.pending_message_helpers.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.pending_message_helpers.incr_with_ttl",
|
||||
new_callable=AsyncMock,
|
||||
return_value=call_count,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.pending_message_helpers.push_pending_message",
|
||||
new_callable=AsyncMock,
|
||||
return_value=1,
|
||||
)
|
||||
# queue_user_message re-runs is_turn_in_flight via the helper module —
|
||||
# stub that path out too so we don't need a fake stream_registry.
|
||||
mocker.patch(
|
||||
"backend.copilot.pending_message_helpers.get_active_session_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_queue_returns_202_when_turn_in_flight(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Happy path: POST /stream to a session with a live turn → 202 queue."""
|
||||
_mock_stream_queue_internals(mocker)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "follow-up", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["buffer_length"] == 1
|
||||
assert "turn_in_flight" in data
|
||||
|
||||
|
||||
def test_stream_queue_session_not_found_returns_404(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""If the session doesn't exist or belong to the user, returns 404."""
|
||||
_mock_stream_queue_internals(mocker, session_exists=False)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/bad-sess/stream",
|
||||
json={"message": "hi", "is_user_message": True},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_stream_queue_call_frequency_limit_returns_429(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Per-user call-frequency cap rejects rapid-fire queued pushes."""
|
||||
from backend.copilot.pending_message_helpers import PENDING_CALL_LIMIT
|
||||
|
||||
_mock_stream_queue_internals(mocker, call_count=PENDING_CALL_LIMIT + 1)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "is_user_message": True},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "Too many queued message requests this minute" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_stream_queue_converts_context_dict_to_pending_context(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""StreamChatRequest.context is a raw dict; must be coerced to the
|
||||
typed PendingMessageContext before being pushed onto the buffer."""
|
||||
_mock_stream_queue_internals(mocker)
|
||||
queue_spy = mocker.patch(
|
||||
"backend.copilot.pending_message_helpers.queue_user_message",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
from backend.copilot.pending_message_helpers import QueuePendingMessageResponse
|
||||
|
||||
queue_spy.return_value = QueuePendingMessageResponse(
|
||||
buffer_length=1, max_buffer_length=10, turn_in_flight=True
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hi",
|
||||
"is_user_message": True,
|
||||
"context": {"url": "https://example.test", "content": "body"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
queue_spy.assert_awaited_once()
|
||||
kwargs = queue_spy.await_args.kwargs
|
||||
from backend.copilot.pending_messages import PendingMessageContext
|
||||
|
||||
assert isinstance(kwargs["context"], PendingMessageContext)
|
||||
assert kwargs["context"].url == "https://example.test"
|
||||
assert kwargs["context"].content == "body"
|
||||
|
||||
|
||||
def test_stream_queue_passes_none_context_when_omitted(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""When request.context is omitted, the queue call receives context=None."""
|
||||
_mock_stream_queue_internals(mocker)
|
||||
queue_spy = mocker.patch(
|
||||
"backend.copilot.pending_message_helpers.queue_user_message",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
from backend.copilot.pending_message_helpers import QueuePendingMessageResponse
|
||||
|
||||
queue_spy.return_value = QueuePendingMessageResponse(
|
||||
buffer_length=1, max_buffer_length=10, turn_in_flight=True
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
queue_spy.assert_awaited_once()
|
||||
assert queue_spy.await_args.kwargs["context"] is None
|
||||
|
||||
|
||||
# ─── get_pending_messages (GET /sessions/{session_id}/messages/pending) ─────
|
||||
|
||||
|
||||
def test_get_pending_messages_returns_200_with_empty_buffer(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Happy path: no pending messages returns 200 with empty list."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.peek_pending_messages",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1/messages/pending")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["messages"] == []
|
||||
assert data["count"] == 0
|
||||
|
||||
|
||||
def test_get_pending_messages_returns_queued_messages(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Returns pending messages from buffer without consuming them."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.peek_pending_messages",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[
|
||||
MagicMock(content="first message"),
|
||||
MagicMock(content="second message"),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1/messages/pending")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["count"] == 2
|
||||
assert data["messages"] == ["first message", "second message"]
|
||||
|
||||
|
||||
def test_get_pending_messages_session_not_found_returns_404(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""If session does not exist or belongs to another user, returns 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
side_effect=fastapi.HTTPException(status_code=404, detail="Session not found."),
|
||||
)
|
||||
|
||||
response = client.get("/sessions/bad-sess/messages/pending")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestStripInjectedContext:
|
||||
"""Unit tests for `_strip_injected_context` — the GET-side helper that
|
||||
hides the server-injected `<user_context>` block from API responses.
|
||||
@@ -706,237 +964,6 @@ class TestStripInjectedContext:
|
||||
assert result["content"] == "hello"
|
||||
|
||||
|
||||
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""A second POST with the same message within the 30-s window must return
|
||||
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
|
||||
turn complete without creating a ghost response."""
|
||||
# redis_set_returns=None simulates a collision: the NX key already exists.
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=None)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-dup/stream",
|
||||
json={"message": "duplicate message", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.text
|
||||
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
|
||||
assert '"finish"' in body
|
||||
assert "[DONE]" in body
|
||||
# The empty SSE response must include the AI SDK protocol header so the
|
||||
# frontend treats it as a valid stream and marks the turn complete.
|
||||
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
|
||||
# The duplicate guard must prevent save/enqueue side effects.
|
||||
ns.save.assert_not_called()
|
||||
ns.enqueue.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_first_post_proceeds_normally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The first POST (Redis NX key set successfully) must proceed through the
|
||||
normal streaming path — no early return."""
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=True)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-new/stream",
|
||||
json={"message": "first message", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Redis set must have been called once with the NX flag.
|
||||
ns.redis.set.assert_called_once()
|
||||
call_kwargs = ns.redis.set.call_args
|
||||
assert call_kwargs.kwargs.get("nx") is True
|
||||
|
||||
|
||||
def test_stream_chat_dedup_skipped_for_non_user_messages(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""System/assistant messages (is_user_message=False) bypass the dedup
|
||||
guard — they are injected programmatically and must always be processed."""
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=None)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-sys/stream",
|
||||
json={"message": "system context", "is_user_message": False},
|
||||
)
|
||||
|
||||
# Even though redis_set_returns=None (would block a user message),
|
||||
# the endpoint must proceed because is_user_message=False.
|
||||
assert response.status_code == 200
|
||||
ns.redis.set.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The dedup hash must be computed from the original request message,
|
||||
not the mutated version that has the [Attached files] block appended.
|
||||
A file_id is sent so the route actually appends the [Attached files] block,
|
||||
exercising the mutation path — the hash must still match the original text."""
|
||||
import hashlib
|
||||
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=True)
|
||||
|
||||
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
# Mock workspace + prisma so the attachment block is actually appended.
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
fake_file = type(
|
||||
"F",
|
||||
(),
|
||||
{
|
||||
"id": file_id,
|
||||
"name": "doc.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"sizeBytes": 1024,
|
||||
},
|
||||
)()
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-hash/stream",
|
||||
json={
|
||||
"message": "plain message",
|
||||
"is_user_message": True,
|
||||
"file_ids": [file_id],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
ns.redis.set.assert_called_once()
|
||||
call_args = ns.redis.set.call_args
|
||||
dedup_key = call_args.args[0]
|
||||
|
||||
# Hash must use the original message + sorted file IDs, not the mutated text.
|
||||
expected_hash = hashlib.sha256(
|
||||
f"sess-hash:plain message:{file_id}".encode()
|
||||
).hexdigest()[:16]
|
||||
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
|
||||
assert dedup_key == expected_key, (
|
||||
f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — "
|
||||
"hash may be using mutated message or wrong inputs"
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_after_stream_finish(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The dedup Redis key must be deleted after the turn completes (when
|
||||
subscriber_queue is None the route yields StreamFinish immediately and
|
||||
should release the key so the user can re-send the same message)."""
|
||||
from unittest.mock import AsyncMock as _AsyncMock
|
||||
|
||||
# Set up all internals manually so we can control subscribe_to_session.
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = _AsyncMock(return_value=None)
|
||||
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
|
||||
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mock_redis = mocker.AsyncMock()
|
||||
mock_redis.set = _AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=_AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-finish/stream",
|
||||
json={"message": "hello", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.text
|
||||
assert '"finish"' in body
|
||||
# The dedup key must be released so intentional re-sends are allowed.
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The route must not crash when the dedup Redis delete fails on the
|
||||
subscriber_queue-is-None early-finish path (except Exception: pass)."""
|
||||
from unittest.mock import AsyncMock as _AsyncMock
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = _AsyncMock(return_value=None)
|
||||
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mock_redis = mocker.AsyncMock()
|
||||
mock_redis.set = _AsyncMock(return_value=True)
|
||||
# Make the delete raise so the except-pass branch is exercised.
|
||||
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=_AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
# Should not raise even though delete fails.
|
||||
response = client.post(
|
||||
"/sessions/sess-finish-err/stream",
|
||||
json={"message": "hello", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert '"finish"' in response.text
|
||||
# delete must have been attempted — the except-pass branch silenced the error.
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
@@ -980,3 +1007,59 @@ def test_disconnect_stream_returns_404_when_session_missing(
|
||||
|
||||
assert response.status_code == 404
|
||||
mock_disconnect.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── GET /sessions/{session_id} — backward pagination ─────────────────────────
|
||||
|
||||
|
||||
def _make_paginated_messages(
|
||||
mocker: pytest_mock.MockerFixture, *, has_more: bool = False
|
||||
):
|
||||
"""Return a mock PaginatedMessages and configure the DB patch."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.db import PaginatedMessages
|
||||
from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata
|
||||
|
||||
now = datetime.now(UTC)
|
||||
session_info = ChatSessionInfo(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
usage=[],
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
metadata=ChatSessionMetadata(),
|
||||
)
|
||||
page = PaginatedMessages(
|
||||
messages=[ChatMessage(role="user", content="hello", sequence=0)],
|
||||
has_more=has_more,
|
||||
oldest_sequence=0,
|
||||
session=session_info,
|
||||
)
|
||||
mock_paginate = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_messages_paginated",
|
||||
new_callable=AsyncMock,
|
||||
return_value=page,
|
||||
)
|
||||
return page, mock_paginate
|
||||
|
||||
|
||||
def test_get_session_returns_backward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""All sessions use backward (newest-first) pagination."""
|
||||
_make_paginated_messages(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["oldest_sequence"] == 0
|
||||
assert "forward_paginated" not in data
|
||||
assert "newest_sequence" not in data
|
||||
|
||||
@@ -12,6 +12,7 @@ import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.api.features.library.db import _fetch_schedule_info
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -117,4 +118,5 @@ async def add_graph_to_library(
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
|
||||
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)
|
||||
|
||||
@@ -21,13 +21,17 @@ async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent)
|
||||
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
@@ -54,6 +58,10 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
@@ -65,7 +73,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent)
|
||||
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
@@ -43,6 +44,65 @@ config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
|
||||
"""Fetch execution counts per graph in a single batched query."""
|
||||
if not graph_ids:
|
||||
return {}
|
||||
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": {"in": graph_ids},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
return {
|
||||
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
|
||||
async def _fetch_schedule_info(
|
||||
user_id: str, graph_id: Optional[str] = None
|
||||
) -> dict[str, str]:
|
||||
"""Fetch a map of graph_id → earliest next_run_time ISO string.
|
||||
|
||||
When `graph_id` is provided, the scheduler query is narrowed to that graph,
|
||||
which is cheaper for single-agent lookups (detail page, post-update, etc.).
|
||||
"""
|
||||
try:
|
||||
scheduler_client = get_scheduler_client()
|
||||
schedules = await scheduler_client.get_execution_schedules(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
earliest: dict[str, tuple[datetime, str]] = {}
|
||||
for s in schedules:
|
||||
parsed = _parse_iso_datetime(s.next_run_time)
|
||||
if parsed is None:
|
||||
continue
|
||||
current = earliest.get(s.graph_id)
|
||||
if current is None or parsed < current[0]:
|
||||
earliest[s.graph_id] = (parsed, s.next_run_time)
|
||||
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
|
||||
except Exception:
|
||||
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_iso_datetime(value: str) -> Optional[datetime]:
|
||||
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
logger.warning("Failed to parse schedule next_run_time: %s", value)
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
|
||||
async def list_library_agents(
|
||||
user_id: str,
|
||||
search_term: Optional[str] = None,
|
||||
@@ -137,12 +197,22 @@ async def list_library_agents(
|
||||
|
||||
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -214,12 +284,22 @@ async def list_favorite_library_agents(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -285,6 +365,12 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
where={"userId": store_listing.owningUserId}
|
||||
)
|
||||
|
||||
schedule_info = (
|
||||
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
|
||||
if library_agent.AgentGraph
|
||||
else {}
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(
|
||||
library_agent,
|
||||
sub_graphs=(
|
||||
@@ -294,6 +380,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
),
|
||||
store_listing=store_listing,
|
||||
profile=profile,
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -329,7 +416,10 @@ async def get_library_agent_by_store_version_id(
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(agent) if agent else None
|
||||
if not agent:
|
||||
return None
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
@@ -358,7 +448,10 @@ async def get_library_agent_by_graph_id(
|
||||
assert agent.AgentGraph # make type checker happy
|
||||
# Include sub-graphs so we can make a full credentials input schema
|
||||
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
|
||||
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(
|
||||
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
|
||||
)
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
@@ -500,7 +593,11 @@ async def create_library_agent(
|
||||
for agent, graph in zip(library_agents, graph_entries):
|
||||
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
|
||||
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in library_agents
|
||||
]
|
||||
|
||||
|
||||
async def update_agent_version_in_library(
|
||||
@@ -562,7 +659,8 @@ async def update_agent_version_in_library(
|
||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
|
||||
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
|
||||
|
||||
|
||||
async def create_graph_in_library(
|
||||
@@ -1467,7 +1565,11 @@ async def bulk_move_agents_to_folder(
|
||||
),
|
||||
)
|
||||
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in agents
|
||||
]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
|
||||
@@ -65,6 +65,11 @@ async def test_get_library_agents(mocker):
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
@@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert():
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_favorite_library_agents(mocker):
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="fav1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-fav",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=True,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-fav",
|
||||
version=1,
|
||||
name="Favorite Agent",
|
||||
description="My Favorite",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
|
||||
)
|
||||
|
||||
result = await db.list_favorite_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 1
|
||||
assert result.agents[0].id == "fav1"
|
||||
assert result.agents[0].name == "Favorite Agent"
|
||||
assert result.agents[0].graph_id == "agent-fav"
|
||||
assert result.pagination.total_items == 1
|
||||
assert result.pagination.total_pages == 1
|
||||
assert result.pagination.current_page == 1
|
||||
assert result.pagination.page_size == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_skips_failed_agent(mocker):
|
||||
"""Agents that fail parsing should be skipped — covers the except branch."""
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua-bad",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-bad",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-bad",
|
||||
version=1,
|
||||
name="Bad Agent",
|
||||
description="",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
side_effect=Exception("parse error"),
|
||||
)
|
||||
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 0
|
||||
assert result.pagination.total_items == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_empty_graph_ids():
|
||||
result = await db._fetch_execution_counts("user-1", [])
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_uses_group_by(mocker):
|
||||
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
|
||||
mock_prisma.return_value.group_by = mocker.AsyncMock(
|
||||
return_value=[
|
||||
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
|
||||
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
|
||||
]
|
||||
)
|
||||
|
||||
result = await db._fetch_execution_counts(
|
||||
"user-1", ["graph-1", "graph-2", "graph-3"]
|
||||
)
|
||||
|
||||
assert result == {"graph-1": 5, "graph-2": 2}
|
||||
mock_prisma.return_value.group_by.assert_called_once_with(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": "user-1",
|
||||
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
|
||||
@@ -214,6 +214,14 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_name: str | None = None # Denormalized for display
|
||||
|
||||
recommended_schedule_cron: str | None = None
|
||||
is_scheduled: bool = pydantic.Field(
|
||||
default=False,
|
||||
description="Whether this agent has active execution schedules",
|
||||
)
|
||||
next_scheduled_run: str | None = pydantic.Field(
|
||||
default=None,
|
||||
description="ISO 8601 timestamp of the next scheduled run, if any",
|
||||
)
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
|
||||
@@ -223,6 +231,8 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
|
||||
store_listing: Optional[prisma.models.StoreListing] = None,
|
||||
profile: Optional[prisma.models.Profile] = None,
|
||||
execution_count_override: Optional[int] = None,
|
||||
schedule_info: Optional[dict[str, str]] = None,
|
||||
) -> "LibraryAgent":
|
||||
"""
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
@@ -258,10 +268,14 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
execution_count = len(executions)
|
||||
execution_count = (
|
||||
execution_count_override
|
||||
if execution_count_override is not None
|
||||
else len(executions)
|
||||
)
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
if execution_count > 0:
|
||||
if executions and execution_count > 0:
|
||||
success_count = sum(
|
||||
1
|
||||
for e in executions
|
||||
@@ -354,6 +368,10 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_id=agent.folderId,
|
||||
folder_name=agent.Folder.name if agent.Folder else None,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
|
||||
next_scheduled_run=(
|
||||
schedule_info.get(agent.agentGraphId) if schedule_info else None
|
||||
),
|
||||
settings=_parse_settings(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,66 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
def _make_library_agent(
|
||||
*,
|
||||
graph_id: str = "g1",
|
||||
executions: list | None = None,
|
||||
) -> prisma.models.LibraryAgent:
|
||||
return prisma.models.LibraryAgent(
|
||||
id="la1",
|
||||
userId="u1",
|
||||
agentGraphId=graph_id,
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.datetime.now(),
|
||||
updatedAt=datetime.datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id=graph_id,
|
||||
version=1,
|
||||
name="Agent",
|
||||
description="Desc",
|
||||
userId="u1",
|
||||
isActive=True,
|
||||
createdAt=datetime.datetime.now(),
|
||||
Executions=executions,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_from_db_execution_count_override_covers_success_rate():
|
||||
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
exec1 = prisma.models.AgentGraphExecution(
|
||||
id="exec-1",
|
||||
agentGraphId="g1",
|
||||
agentGraphVersion=1,
|
||||
userId="u1",
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
agent = _make_library_agent(executions=[exec1])
|
||||
|
||||
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
|
||||
|
||||
assert result.execution_count == 1
|
||||
assert result.success_rate is not None
|
||||
assert result.success_rate == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_preset_from_db(test_user_id: str):
|
||||
# Create mock DB agent
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,8 @@ import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Literal, Sequence, get_args
|
||||
from typing import Annotated, Any, Literal, Sequence, cast, get_args
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -25,7 +26,7 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import SubscriptionTier
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -48,17 +49,24 @@ from backend.data.auth import api_key as api_key_db
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
PendingChangeUnknown,
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
UserCredit,
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_auto_top_up,
|
||||
get_pending_subscription_change,
|
||||
get_proration_credit_cents,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
modify_stripe_subscription_for_tier,
|
||||
release_pending_subscription_schedule,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
sync_subscription_schedule_from_stripe,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
@@ -694,14 +702,83 @@ class SubscriptionTierRequest(BaseModel):
|
||||
cancel_url: str = ""
|
||||
|
||||
|
||||
class SubscriptionCheckoutResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
tier: str
|
||||
monthly_cost: int
|
||||
tier_costs: dict[str, int]
|
||||
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
|
||||
monthly_cost: int # amount in cents (Stripe convention)
|
||||
tier_costs: dict[str, int] # tier name -> amount in cents
|
||||
proration_credit_cents: int # unused portion of current sub to convert on upgrade
|
||||
pending_tier: Optional[Literal["FREE", "PRO", "BUSINESS"]] = None
|
||||
pending_tier_effective_at: Optional[datetime] = None
|
||||
url: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"Populated only when POST /credits/subscription starts a Stripe Checkout"
|
||||
" Session (FREE → paid upgrade). Empty string in all other branches —"
|
||||
" the client redirects to this URL when non-empty."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
"""Return True if `url` matches the configured frontend origin.
|
||||
|
||||
Prevents open-redirect: attackers must not be able to supply arbitrary
|
||||
success_url/cancel_url that Stripe will redirect users to after checkout.
|
||||
|
||||
Pre-parse rejection rules (applied before urlparse):
|
||||
- Backslashes (``\\``) are normalised differently across parsers/browsers.
|
||||
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
|
||||
some URL-parsing implementations.
|
||||
"""
|
||||
# Reject characters that can confuse URL parsers before any parsing.
|
||||
if "\\" in url:
|
||||
return False
|
||||
if any(ord(c) < 0x20 for c in url):
|
||||
return False
|
||||
|
||||
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
if not allowed:
|
||||
# No configured origin — refuse to validate rather than allow arbitrary URLs.
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
allowed_parsed = urlparse(allowed)
|
||||
except ValueError:
|
||||
return False
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
|
||||
# can trick browsers into connecting to a different host than displayed.
|
||||
# ``@`` in query/fragment is harmless and must be allowed.
|
||||
if "@" in parsed.netloc:
|
||||
return False
|
||||
return (
|
||||
parsed.scheme == allowed_parsed.scheme
|
||||
and parsed.netloc == allowed_parsed.netloc
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
|
||||
async def _get_stripe_price_amount(price_id: str) -> int | None:
|
||||
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
|
||||
|
||||
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
|
||||
of caching the ``None`` sentinel so the next request retries Stripe instead
|
||||
of being served a stale "no price" for the rest of the TTL window. Callers
|
||||
should treat ``None`` as an unknown price and fall back to 0.
|
||||
|
||||
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
|
||||
every GET /credits/subscription page load and reduces quota consumption.
|
||||
"""
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
return price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"Failed to retrieve Stripe price %s — returning None (not cached)",
|
||||
price_id,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -722,27 +799,57 @@ async def get_subscription_status(
|
||||
*[get_subscription_price_id(t) for t in paid_tiers]
|
||||
)
|
||||
|
||||
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
|
||||
for t, price_id in zip(paid_tiers, price_ids):
|
||||
cost = 0
|
||||
if price_id:
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
cost = price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
pass
|
||||
tier_costs: dict[str, int] = {
|
||||
SubscriptionTier.FREE.value: 0,
|
||||
SubscriptionTier.ENTERPRISE.value: 0,
|
||||
}
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
for t, cost in zip(paid_tiers, costs):
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
return SubscriptionStatusResponse(
|
||||
current_monthly_cost = tier_costs.get(tier.value, 0)
|
||||
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
|
||||
|
||||
try:
|
||||
pending = await get_pending_subscription_change(user_id)
|
||||
except (stripe.StripeError, PendingChangeUnknown):
|
||||
# Swallow Stripe-side failures (rate limits, transient network) AND
|
||||
# PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
|
||||
# propagate past the cache so the next request retries fresh instead
|
||||
# of serving a stale None for the TTL window. Let real bugs (KeyError,
|
||||
# AttributeError, etc.) propagate so they surface in Sentry.
|
||||
logger.exception(
|
||||
"get_subscription_status: failed to resolve pending change for user %s",
|
||||
user_id,
|
||||
)
|
||||
pending = None
|
||||
|
||||
response = SubscriptionStatusResponse(
|
||||
tier=tier.value,
|
||||
monthly_cost=tier_costs.get(tier.value, 0),
|
||||
monthly_cost=current_monthly_cost,
|
||||
tier_costs=tier_costs,
|
||||
proration_credit_cents=proration_credit,
|
||||
)
|
||||
if pending is not None:
|
||||
pending_tier_enum, pending_effective_at = pending
|
||||
if pending_tier_enum == SubscriptionTier.FREE:
|
||||
response.pending_tier = "FREE"
|
||||
elif pending_tier_enum == SubscriptionTier.PRO:
|
||||
response.pending_tier = "PRO"
|
||||
elif pending_tier_enum == SubscriptionTier.BUSINESS:
|
||||
response.pending_tier = "BUSINESS"
|
||||
if response.pending_tier is not None:
|
||||
response.pending_tier_effective_at = pending_effective_at
|
||||
return response
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/subscription",
|
||||
summary="Start a Stripe Checkout session to upgrade subscription tier",
|
||||
summary="Update subscription tier or start a Stripe Checkout session",
|
||||
operation_id="updateSubscriptionTier",
|
||||
tags=["credits"],
|
||||
dependencies=[Security(requires_user)],
|
||||
@@ -750,7 +857,7 @@ async def get_subscription_status(
|
||||
async def update_subscription_tier(
|
||||
request: SubscriptionTierRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> SubscriptionCheckoutResponse:
|
||||
) -> SubscriptionStatusResponse:
|
||||
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
|
||||
tier = SubscriptionTier(request.tier)
|
||||
|
||||
@@ -762,28 +869,143 @@ async def update_subscription_tier(
|
||||
detail="ENTERPRISE subscription changes must be managed by an administrator",
|
||||
)
|
||||
|
||||
# Same-tier request = "stay on my current tier" = cancel any pending
|
||||
# scheduled change (paid→paid downgrade or paid→FREE cancel). This is the
|
||||
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
|
||||
# route. Safe when no pending change exists: release_pending_subscription_schedule
|
||||
# returns False and we simply return the current status.
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
|
||||
try:
|
||||
await release_pending_subscription_schedule(user_id)
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error releasing pending subscription change for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel the pending subscription change right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
return await get_subscription_status(user_id)
|
||||
|
||||
payment_enabled = await is_feature_enabled(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
|
||||
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
|
||||
# keeps their tier for the time they already paid for. The DB tier is NOT
|
||||
# updated here when a subscription exists — the customer.subscription.deleted
|
||||
# webhook fires at period end and downgrades to FREE then.
|
||||
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
|
||||
# tier), cancel_stripe_subscription returns False and we update the DB tier
|
||||
# immediately since no webhook will ever fire.
|
||||
# When payment is disabled entirely, update the DB tier directly.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
if payment_enabled:
|
||||
await cancel_stripe_subscription(user_id)
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
# Log full Stripe error server-side but return a generic message
|
||||
# to the client — raw Stripe errors can leak customer/sub IDs and
|
||||
# infrastructure config details.
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
# No active Stripe subscription found — the user was on an
|
||||
# admin-granted tier. Update DB immediately since the
|
||||
# subscription.deleted webhook will never fire.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
return await get_subscription_status(user_id)
|
||||
|
||||
# Beta users (payment not enabled) → update tier directly without Stripe.
|
||||
# Paid tier changes require payment to be enabled — block self-service upgrades
|
||||
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
|
||||
if not payment_enabled:
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier}",
|
||||
)
|
||||
|
||||
# Paid upgrade → create Stripe Checkout Session.
|
||||
# Paid→paid tier change: if the user already has a Stripe subscription,
|
||||
# modify it in-place with proration instead of creating a new Checkout
|
||||
# Session. This preserves remaining paid time and avoids double-charging.
|
||||
# The customer.subscription.updated webhook fires and updates the DB tier.
|
||||
current_tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return await get_subscription_status(user_id)
|
||||
# modify_stripe_subscription_for_tier returns False when no active
|
||||
# Stripe subscription exists — i.e. the user has an admin-granted
|
||||
# paid tier with no Stripe record. In that case, update the DB
|
||||
# tier directly (same as the FREE-downgrade path for admin-granted
|
||||
# users) rather than sending them through a new Checkout Session.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return await get_subscription_status(user_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
# Paid upgrade from FREE → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url are required for paid tier upgrades",
|
||||
)
|
||||
# Open-redirect protection: both URLs must point to the configured frontend
|
||||
# origin, otherwise an attacker could use our Stripe integration as a
|
||||
# redirector to arbitrary phishing sites.
|
||||
#
|
||||
# Fail early with a clear 503 if the server is misconfigured (neither
|
||||
# frontend_base_url nor platform_base_url set), so operators get an
|
||||
# actionable error instead of the misleading "must match the platform
|
||||
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
|
||||
# produce when `allowed` is empty.
|
||||
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
|
||||
logger.error(
|
||||
"update_subscription_tier: neither frontend_base_url nor "
|
||||
"platform_base_url is configured; cannot validate checkout redirect URLs"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Payment redirect URLs cannot be validated: "
|
||||
"frontend_base_url or platform_base_url must be set on the server."
|
||||
),
|
||||
)
|
||||
if not _validate_checkout_redirect_url(
|
||||
request.success_url
|
||||
) or not _validate_checkout_redirect_url(request.cancel_url):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url must match the platform frontend origin",
|
||||
)
|
||||
try:
|
||||
url = await create_subscription_checkout(
|
||||
user_id=user_id,
|
||||
@@ -791,54 +1013,113 @@ async def update_subscription_tier(
|
||||
success_url=request.success_url,
|
||||
cancel_url=request.cancel_url,
|
||||
)
|
||||
except (ValueError, stripe.StripeError) as e:
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error creating checkout session for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to start checkout right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
return SubscriptionCheckoutResponse(url=url)
|
||||
status = await get_subscription_status(user_id)
|
||||
status.url = url
|
||||
return status
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
async def stripe_webhook(request: Request):
|
||||
webhook_secret = settings.secrets.stripe_webhook_secret
|
||||
if not webhook_secret:
|
||||
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
|
||||
# signature over the same empty key). Reject all webhook calls when unconfigured.
|
||||
logger.error(
|
||||
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
|
||||
"rejecting request to prevent signature bypass"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Webhook not configured")
|
||||
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
# Get the signature header
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
||||
except ValueError:
|
||||
# Invalid payload
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.SignatureVerificationError:
|
||||
# Invalid signature
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Defensive payload extraction. A malformed payload (missing/non-dict
|
||||
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
|
||||
# AFTER signature verification — which Stripe interprets as a delivery
|
||||
# failure and retries forever, while spamming Sentry with no useful info.
|
||||
# Acknowledge with 200 and a warning so Stripe stops retrying.
|
||||
event_type = event.get("type", "")
|
||||
event_data = event.get("data") or {}
|
||||
data_object = event_data.get("object") if isinstance(event_data, dict) else None
|
||||
if not isinstance(data_object, dict):
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing or non-dict data.object; ignoring",
|
||||
event_type,
|
||||
)
|
||||
return Response(status_code=200)
|
||||
|
||||
if (
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
if event_type in (
|
||||
"checkout.session.completed",
|
||||
"checkout.session.async_payment_succeeded",
|
||||
):
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
session_id = data_object.get("id")
|
||||
if not session_id:
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing data.object.id; ignoring", event_type
|
||||
)
|
||||
return Response(status_code=200)
|
||||
await UserCredit().fulfill_checkout(session_id=session_id)
|
||||
|
||||
if event["type"] in (
|
||||
if event_type in (
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
):
|
||||
await sync_subscription_from_stripe(event["data"]["object"])
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
# `subscription_schedule.updated` is deliberately omitted: our own
|
||||
# `SubscriptionSchedule.create` + `.modify` calls in
|
||||
# `_schedule_downgrade_at_period_end` would fire that event right back at us
|
||||
# and loop redundant traffic through this handler. We only care about state
|
||||
# transitions (released / completed); phase advance to the new price is
|
||||
# already covered by `customer.subscription.updated`.
|
||||
if event_type in (
|
||||
"subscription_schedule.released",
|
||||
"subscription_schedule.completed",
|
||||
):
|
||||
await sync_subscription_schedule_from_stripe(data_object)
|
||||
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||
# to satisfy the type checker without changing runtime behaviour.
|
||||
if event_type == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
|
||||
|
||||
if event_type == "refund.created" or event_type == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(
|
||||
cast("stripe.Refund | stripe.Dispute", data_object)
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from backend.copilot.permissions import (
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
@@ -32,9 +33,36 @@ logger = logging.getLogger(__name__)
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
# Identifiers used when registering an AutoPilotBlock turn with the
|
||||
# stream registry — distinguishes block-originated turns from sub-session
|
||||
# or HTTP SSE turns in logs / observability.
|
||||
_AUTOPILOT_TOOL_CALL_ID = "autopilot_block"
|
||||
_AUTOPILOT_TOOL_NAME = "autopilot_block"
|
||||
|
||||
class SubAgentRecursionError(RuntimeError):
|
||||
"""Raised when the sub-agent nesting depth limit is exceeded."""
|
||||
# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the
|
||||
# enqueued turn's terminal event. Graph blocks run synchronously from
|
||||
# the caller's perspective so we wait effectively as long as needed; 6h
|
||||
# matches the previous abandoned-task cap and is much longer than any
|
||||
# legitimate AutoPilot turn.
|
||||
_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60 # 6 hours
|
||||
|
||||
|
||||
class SubAgentRecursionError(BlockExecutionError):
|
||||
"""Raised when the AutoPilot sub-agent nesting depth limit is exceeded.
|
||||
|
||||
Inherits :class:`BlockExecutionError` — this is a known, handled
|
||||
runtime failure at the block level (caller nested AutoPilotBlocks
|
||||
beyond the configured limit). Surfaces with the block_name /
|
||||
block_id the block framework expects, instead of being wrapped in
|
||||
``BlockUnknownError``.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(
|
||||
message=message,
|
||||
block_name="AutoPilotBlock",
|
||||
block_id=AUTOPILOT_BLOCK_ID,
|
||||
)
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
@@ -268,11 +296,15 @@ class AutoPilotBlock(Block):
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
"""Invoke the copilot on the copilot_executor queue and aggregate the
|
||||
result.
|
||||
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
Delegates to :func:`run_copilot_turn_via_queue` — the shared
|
||||
primitive used by ``run_sub_session`` too — which creates the
|
||||
stream_registry meta record, enqueues the job, and waits on the
|
||||
Redis stream for the terminal event. Any available
|
||||
copilot_executor worker picks up the job, so this call survives
|
||||
the graph-executor worker dying mid-turn (RabbitMQ redelivers).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
@@ -285,8 +317,8 @@ class AutoPilotBlock(Block):
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import (
|
||||
collect_copilot_response, # avoid circular import
|
||||
from backend.copilot.sdk.session_waiter import (
|
||||
run_copilot_turn_via_queue, # avoid circular import
|
||||
)
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
@@ -299,14 +331,35 @@ class AutoPilotBlock(Block):
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
result = await collect_copilot_response(
|
||||
outcome, result = await run_copilot_turn_via_queue(
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
message=effective_prompt,
|
||||
# Graph block execution is synchronous from the caller's
|
||||
# perspective — wait effectively as long as needed. The
|
||||
# SDK enforces its own idle-based timeout inside the
|
||||
# stream_registry pipeline.
|
||||
timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS,
|
||||
permissions=effective_permissions,
|
||||
tool_call_id=_AUTOPILOT_TOOL_CALL_ID,
|
||||
tool_name=_AUTOPILOT_TOOL_NAME,
|
||||
)
|
||||
if outcome == "failed":
|
||||
raise RuntimeError(
|
||||
"AutoPilot turn failed — see the session's transcript"
|
||||
)
|
||||
if outcome == "running":
|
||||
raise RuntimeError(
|
||||
"AutoPilot turn did not complete within "
|
||||
f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session "
|
||||
f"{session_id}"
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
# Build a lightweight conversation summary from the aggregated data.
|
||||
# When ``result.queued`` is True the prompt rode on an already-
|
||||
# in-flight turn (``run_copilot_turn_via_queue`` queued it and
|
||||
# waited on the existing turn's stream); the aggregated result
|
||||
# is still valid, so the same rendering path applies.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
@@ -315,7 +368,7 @@ class AutoPilotBlock(Block):
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": result.tool_calls,
|
||||
"tool_calls": [tc.model_dump() for tc in result.tool_calls],
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -326,11 +379,11 @@ class AutoPilotBlock(Block):
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
"tool_call_id": tc.tool_call_id,
|
||||
"tool_name": tc.tool_name,
|
||||
"input": tc.input,
|
||||
"output": tc.output,
|
||||
"success": tc.success,
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
|
||||
@@ -106,7 +106,6 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
@@ -203,6 +202,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
GROK_4_20 = "x-ai/grok-4.20"
|
||||
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
|
||||
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
@@ -627,6 +628,18 @@ MODEL_METADATA = {
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_4_20: ModelMetadata(
|
||||
"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
|
||||
),
|
||||
LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
|
||||
"open_router",
|
||||
2000000,
|
||||
100000,
|
||||
"Grok 4.20 Multi-Agent",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
3,
|
||||
),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
||||
),
|
||||
@@ -987,7 +1000,6 @@ async def llm_call(
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
|
||||
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
|
||||
# Cache tool definitions alongside the system prompt.
|
||||
# Placing cache_control on the last tool caches all tool schemas as a
|
||||
|
||||
230
autogpt_platform/backend/backend/copilot/baseline/reasoning.py
Normal file
230
autogpt_platform/backend/backend/copilot/baseline/reasoning.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Extended-thinking wire support for the baseline (OpenRouter) path.
|
||||
|
||||
Anthropic routes on OpenRouter expose extended thinking through
|
||||
non-OpenAI extension fields that the OpenAI Python SDK doesn't model:
|
||||
|
||||
* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``.
|
||||
* ``reasoning_content`` — DeepSeek / some OpenRouter routes.
|
||||
* ``reasoning_details`` — structured list shipped with the unified
|
||||
``reasoning`` request param.
|
||||
|
||||
This module keeps the wire-level concerns in one place:
|
||||
|
||||
* :class:`OpenRouterDeltaExtension` validates the extension dict pulled off
|
||||
``ChoiceDelta.model_extra`` into typed pydantic models — no ``getattr`` +
|
||||
``isinstance`` duck-typing at the call site.
|
||||
* :class:`BaselineReasoningEmitter` owns the reasoning block lifecycle for
|
||||
one streaming round and emits ``StreamReasoning*`` events so the caller
|
||||
only has to plumb the events into its pending queue.
|
||||
* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the
|
||||
OpenAI client call. Returns ``None`` on non-Anthropic routes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamReasoningDelta,
|
||||
StreamReasoningEnd,
|
||||
StreamReasoningStart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
|
||||
|
||||
|
||||
class ReasoningDetail(BaseModel):
|
||||
"""One entry in OpenRouter's ``reasoning_details`` list.
|
||||
|
||||
OpenRouter ships ``type: "reasoning.text"`` / ``"reasoning.summary"`` /
|
||||
``"reasoning.encrypted"`` entries. Only the first two carry
|
||||
user-visible text; encrypted entries are opaque and omitted from the
|
||||
rendered collapse. Unknown future types are tolerated (``extra="ignore"``)
|
||||
so an upstream addition doesn't crash the stream — but their ``text`` /
|
||||
``summary`` fields are NOT surfaced because they may carry provider
|
||||
metadata rather than user-visible reasoning (see
|
||||
:attr:`visible_text`).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
type: str | None = None
|
||||
text: str | None = None
|
||||
summary: str | None = None
|
||||
|
||||
@property
|
||||
def visible_text(self) -> str:
|
||||
"""Return the human-readable text for this entry, or ``""``.
|
||||
|
||||
Only entries with a recognised reasoning type (``reasoning.text`` /
|
||||
``reasoning.summary``) surface text; unknown or encrypted types
|
||||
return an empty string even if they carry a ``text`` /
|
||||
``summary`` field, to guard against future provider metadata
|
||||
being rendered as reasoning in the UI. Entries missing a
|
||||
``type`` are treated as text (pre-``reasoning_details`` OpenRouter
|
||||
payloads omit the field).
|
||||
"""
|
||||
if self.type is not None and self.type not in _VISIBLE_REASONING_TYPES:
|
||||
return ""
|
||||
return self.text or self.summary or ""
|
||||
|
||||
|
||||
class OpenRouterDeltaExtension(BaseModel):
|
||||
"""Non-OpenAI fields OpenRouter adds to streaming deltas.
|
||||
|
||||
Instantiate via :meth:`from_delta` which pulls the extension dict off
|
||||
``ChoiceDelta.model_extra`` (where pydantic v2 stashes fields that
|
||||
aren't part of the declared schema) and validates it through this
|
||||
model. That keeps the parser honest — malformed entries surface as
|
||||
validation errors rather than silent ``None``-coalesce bugs — and
|
||||
avoids the ``getattr`` + ``isinstance`` duck-typing the earlier inline
|
||||
extractor relied on.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
reasoning: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
reasoning_details: list[ReasoningDetail] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_delta(cls, delta: ChoiceDelta) -> "OpenRouterDeltaExtension":
|
||||
"""Build an extension view from ``delta.model_extra``.
|
||||
|
||||
Malformed provider payloads (e.g. ``reasoning_details`` shipped as
|
||||
a string rather than a list) surface as a ``ValidationError`` which
|
||||
is logged and swallowed — returning an empty extension so the rest
|
||||
of the stream (valid text / tool calls) keeps flowing. An optional
|
||||
feature's corrupted wire data must never abort the whole stream.
|
||||
"""
|
||||
try:
|
||||
return cls.model_validate(delta.model_extra or {})
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"[Baseline] Dropping malformed OpenRouter reasoning payload: %s",
|
||||
exc,
|
||||
)
|
||||
return cls()
|
||||
|
||||
def visible_text(self) -> str:
|
||||
"""Concatenated reasoning text, pulled from whichever channel is set.
|
||||
|
||||
Priority: the legacy ``reasoning`` string, then DeepSeek's
|
||||
``reasoning_content``, then the concatenation of text-bearing
|
||||
entries in ``reasoning_details``. Only one channel is set per
|
||||
provider in practice; the priority order just makes the fallback
|
||||
deterministic if a provider ever emits multiple.
|
||||
"""
|
||||
if self.reasoning:
|
||||
return self.reasoning
|
||||
if self.reasoning_content:
|
||||
return self.reasoning_content
|
||||
return "".join(d.visible_text for d in self.reasoning_details)
|
||||
|
||||
|
||||
def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None:
|
||||
"""Build the ``extra_body["reasoning"]`` fragment for the OpenAI client.
|
||||
|
||||
Returns ``None`` for non-Anthropic routes (other OpenRouter providers
|
||||
ignore the field but we skip it anyway to keep the payload minimal)
|
||||
and for ``max_thinking_tokens <= 0`` (operator kill switch).
|
||||
"""
|
||||
# Imported lazily to avoid pulling service.py at module load — service.py
|
||||
# imports this module, and the lazy import keeps the dependency one-way.
|
||||
from backend.copilot.baseline.service import _is_anthropic_model
|
||||
|
||||
if not _is_anthropic_model(model) or max_thinking_tokens <= 0:
|
||||
return None
|
||||
return {"reasoning": {"max_tokens": max_thinking_tokens}}
|
||||
|
||||
|
||||
class BaselineReasoningEmitter:
|
||||
"""Owns the reasoning block lifecycle for one streaming round.
|
||||
|
||||
Two concerns live here, both driven by the same state machine:
|
||||
|
||||
1. **Wire events.** The AI SDK v6 wire format pairs every
|
||||
``reasoning-start`` with a matching ``reasoning-end`` and treats
|
||||
reasoning / text / tool-use as distinct UI parts that must not
|
||||
interleave.
|
||||
2. **Session persistence.** ``ChatMessage(role="reasoning")`` rows in
|
||||
``session.messages`` are what
|
||||
``convertChatSessionToUiMessages.ts`` folds into the assistant
|
||||
bubble as ``{type: "reasoning"}`` UI parts on reload and on
|
||||
``useHydrateOnStreamEnd`` swaps. Without them the live-streamed
|
||||
reasoning parts get overwritten by the hydrated (reasoning-less)
|
||||
message list the moment the stream ends. Mirrors the SDK path's
|
||||
``acc.reasoning_response`` pattern so both routes render the same
|
||||
way on reload.
|
||||
|
||||
Pass ``session_messages`` to enable persistence; omit for pure
|
||||
wire-emission (tests, scratch callers). On first reasoning delta a
|
||||
fresh ``ChatMessage(role="reasoning")`` is appended and mutated
|
||||
in-place as further deltas arrive; :meth:`close` drops the reference
|
||||
but leaves the appended row intact.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_messages: list[ChatMessage] | None = None,
|
||||
) -> None:
|
||||
self._block_id: str = str(uuid.uuid4())
|
||||
self._open: bool = False
|
||||
self._session_messages = session_messages
|
||||
self._current_row: ChatMessage | None = None
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
return self._open
|
||||
|
||||
def on_delta(self, delta: ChoiceDelta) -> list[StreamBaseResponse]:
|
||||
"""Return events for the reasoning text carried by *delta*.
|
||||
|
||||
Empty list when the chunk carries no reasoning payload, so this is
|
||||
safe to call on every chunk without guarding at the call site.
|
||||
Persistence (when a session message list is attached) happens in
|
||||
lockstep with emission so the row's content stays equal to the
|
||||
concatenated deltas at every delta boundary.
|
||||
"""
|
||||
ext = OpenRouterDeltaExtension.from_delta(delta)
|
||||
text = ext.visible_text()
|
||||
if not text:
|
||||
return []
|
||||
events: list[StreamBaseResponse] = []
|
||||
if not self._open:
|
||||
events.append(StreamReasoningStart(id=self._block_id))
|
||||
self._open = True
|
||||
if self._session_messages is not None:
|
||||
self._current_row = ChatMessage(role="reasoning", content="")
|
||||
self._session_messages.append(self._current_row)
|
||||
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
|
||||
if self._current_row is not None:
|
||||
self._current_row.content = (self._current_row.content or "") + text
|
||||
return events
|
||||
|
||||
def close(self) -> list[StreamBaseResponse]:
|
||||
"""Emit ``StreamReasoningEnd`` for the open block (if any) and rotate.
|
||||
|
||||
Idempotent — returns ``[]`` when no block is open. The id rotation
|
||||
guarantees the next reasoning block starts with a fresh id rather
|
||||
than reusing one already closed on the wire. The persisted row is
|
||||
not removed — it stays in ``session_messages`` as the durable
|
||||
record of what was reasoned.
|
||||
"""
|
||||
if not self._open:
|
||||
return []
|
||||
event = StreamReasoningEnd(id=self._block_id)
|
||||
self._open = False
|
||||
self._block_id = str(uuid.uuid4())
|
||||
self._current_row = None
|
||||
return [event]
|
||||
@@ -0,0 +1,281 @@
|
||||
"""Tests for the baseline reasoning extension module.
|
||||
|
||||
Covers the typed OpenRouter delta parser, the stateful emitter, and the
|
||||
``extra_body`` builder. The emitter is tested against real
|
||||
``ChoiceDelta`` pydantic instances so the ``model_extra`` plumbing the
|
||||
parser relies on is exercised end-to-end.
|
||||
"""
|
||||
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
|
||||
from backend.copilot.baseline.reasoning import (
|
||||
BaselineReasoningEmitter,
|
||||
OpenRouterDeltaExtension,
|
||||
ReasoningDetail,
|
||||
reasoning_extra_body,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.response_model import (
|
||||
StreamReasoningDelta,
|
||||
StreamReasoningEnd,
|
||||
StreamReasoningStart,
|
||||
)
|
||||
|
||||
|
||||
def _delta(**extra) -> ChoiceDelta:
|
||||
"""Build a ChoiceDelta with the given extension fields on ``model_extra``."""
|
||||
return ChoiceDelta.model_validate({"role": "assistant", **extra})
|
||||
|
||||
|
||||
class TestReasoningDetail:
|
||||
def test_visible_text_prefers_text(self):
|
||||
d = ReasoningDetail(type="reasoning.text", text="hi", summary="ignored")
|
||||
assert d.visible_text == "hi"
|
||||
|
||||
def test_visible_text_falls_back_to_summary(self):
|
||||
d = ReasoningDetail(type="reasoning.summary", summary="tldr")
|
||||
assert d.visible_text == "tldr"
|
||||
|
||||
def test_visible_text_empty_for_encrypted(self):
|
||||
d = ReasoningDetail(type="reasoning.encrypted")
|
||||
assert d.visible_text == ""
|
||||
|
||||
def test_unknown_fields_are_ignored(self):
|
||||
# OpenRouter may add new fields in future payloads — they shouldn't
|
||||
# cause validation errors.
|
||||
d = ReasoningDetail.model_validate(
|
||||
{"type": "reasoning.future", "text": "x", "signature": "opaque"}
|
||||
)
|
||||
assert d.text == "x"
|
||||
|
||||
def test_visible_text_empty_for_unknown_type(self):
|
||||
# Unknown types may carry provider metadata that must not render as
|
||||
# user-visible reasoning — regardless of whether a text/summary is
|
||||
# present. Only ``reasoning.text`` / ``reasoning.summary`` surface.
|
||||
d = ReasoningDetail(type="reasoning.future", text="leaked metadata")
|
||||
assert d.visible_text == ""
|
||||
|
||||
def test_visible_text_surfaces_text_when_type_missing(self):
|
||||
# Pre-``reasoning_details`` OpenRouter payloads omit ``type`` — treat
|
||||
# them as text so we don't regress the legacy structured shape.
|
||||
d = ReasoningDetail(text="plain")
|
||||
assert d.visible_text == "plain"
|
||||
|
||||
|
||||
class TestOpenRouterDeltaExtension:
|
||||
def test_from_delta_reads_model_extra(self):
|
||||
delta = _delta(reasoning="step one")
|
||||
ext = OpenRouterDeltaExtension.from_delta(delta)
|
||||
assert ext.reasoning == "step one"
|
||||
|
||||
def test_visible_text_legacy_string(self):
|
||||
ext = OpenRouterDeltaExtension(reasoning="plain text")
|
||||
assert ext.visible_text() == "plain text"
|
||||
|
||||
def test_visible_text_deepseek_alias(self):
|
||||
ext = OpenRouterDeltaExtension(reasoning_content="alt channel")
|
||||
assert ext.visible_text() == "alt channel"
|
||||
|
||||
def test_visible_text_structured_details_concat(self):
|
||||
ext = OpenRouterDeltaExtension(
|
||||
reasoning_details=[
|
||||
ReasoningDetail(type="reasoning.text", text="hello "),
|
||||
ReasoningDetail(type="reasoning.text", text="world"),
|
||||
]
|
||||
)
|
||||
assert ext.visible_text() == "hello world"
|
||||
|
||||
def test_visible_text_skips_encrypted(self):
|
||||
ext = OpenRouterDeltaExtension(
|
||||
reasoning_details=[
|
||||
ReasoningDetail(type="reasoning.encrypted"),
|
||||
ReasoningDetail(type="reasoning.text", text="visible"),
|
||||
]
|
||||
)
|
||||
assert ext.visible_text() == "visible"
|
||||
|
||||
def test_visible_text_empty_when_all_channels_blank(self):
|
||||
ext = OpenRouterDeltaExtension()
|
||||
assert ext.visible_text() == ""
|
||||
|
||||
def test_empty_delta_produces_empty_extension(self):
|
||||
ext = OpenRouterDeltaExtension.from_delta(_delta())
|
||||
assert ext.reasoning is None
|
||||
assert ext.reasoning_content is None
|
||||
assert ext.reasoning_details == []
|
||||
|
||||
def test_malformed_reasoning_payload_logged_and_swallowed(self, caplog):
|
||||
# A malformed payload (e.g. reasoning_details shipped as a string
|
||||
# rather than a list) must not abort the stream — log it and
|
||||
# return an empty extension so valid text/tool events keep flowing.
|
||||
# A plain mock is used here because ``from_delta`` only reads
|
||||
# ``delta.model_extra`` — avoids reaching into pydantic internals
|
||||
# (``__pydantic_extra__``) that could be renamed across versions.
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
delta = MagicMock(spec=ChoiceDelta)
|
||||
delta.model_extra = {"reasoning_details": "not a list"}
|
||||
with caplog.at_level("WARNING"):
|
||||
ext = OpenRouterDeltaExtension.from_delta(delta)
|
||||
assert ext.reasoning_details == []
|
||||
assert ext.visible_text() == ""
|
||||
assert any("malformed" in r.message.lower() for r in caplog.records)
|
||||
|
||||
def test_unknown_typed_entry_with_text_is_not_surfaced(self):
|
||||
# Regression: the legacy extractor emitted any entry with a
|
||||
# ``text`` or ``summary`` field. The typed parser now filters on
|
||||
# the recognised types so future provider metadata can't leak
|
||||
# into the reasoning collapse.
|
||||
ext = OpenRouterDeltaExtension(
|
||||
reasoning_details=[
|
||||
ReasoningDetail(type="reasoning.future", text="provider metadata"),
|
||||
ReasoningDetail(type="reasoning.text", text="real"),
|
||||
]
|
||||
)
|
||||
assert ext.visible_text() == "real"
|
||||
|
||||
|
||||
class TestReasoningExtraBody:
|
||||
def test_anthropic_route_returns_fragment(self):
|
||||
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == {
|
||||
"reasoning": {"max_tokens": 4096}
|
||||
}
|
||||
|
||||
def test_direct_claude_model_id_still_matches(self):
|
||||
assert reasoning_extra_body("claude-3-5-sonnet-20241022", 2048) == {
|
||||
"reasoning": {"max_tokens": 2048}
|
||||
}
|
||||
|
||||
def test_non_anthropic_route_returns_none(self):
|
||||
assert reasoning_extra_body("openai/gpt-4o", 4096) is None
|
||||
assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None
|
||||
|
||||
def test_zero_max_tokens_kill_switch(self):
|
||||
# Operator kill switch: ``max_thinking_tokens <= 0`` disables the
|
||||
# ``reasoning`` extra_body fragment even on an Anthropic route.
|
||||
# Lets us silence reasoning without dropping the SDK path's budget.
|
||||
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None
|
||||
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None
|
||||
|
||||
|
||||
class TestBaselineReasoningEmitter:
|
||||
def test_first_text_delta_emits_start_then_delta(self):
|
||||
emitter = BaselineReasoningEmitter()
|
||||
events = emitter.on_delta(_delta(reasoning="thinking"))
|
||||
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], StreamReasoningStart)
|
||||
assert isinstance(events[1], StreamReasoningDelta)
|
||||
assert events[0].id == events[1].id
|
||||
assert events[1].delta == "thinking"
|
||||
assert emitter.is_open is True
|
||||
|
||||
def test_subsequent_deltas_reuse_block_id_without_new_start(self):
|
||||
emitter = BaselineReasoningEmitter()
|
||||
first = emitter.on_delta(_delta(reasoning="a"))
|
||||
second = emitter.on_delta(_delta(reasoning="b"))
|
||||
|
||||
assert any(isinstance(e, StreamReasoningStart) for e in first)
|
||||
assert all(not isinstance(e, StreamReasoningStart) for e in second)
|
||||
assert len(second) == 1
|
||||
assert isinstance(second[0], StreamReasoningDelta)
|
||||
assert first[0].id == second[0].id
|
||||
|
||||
def test_empty_delta_emits_nothing(self):
|
||||
emitter = BaselineReasoningEmitter()
|
||||
assert emitter.on_delta(_delta(content="hello")) == []
|
||||
assert emitter.is_open is False
|
||||
|
||||
def test_close_emits_end_and_rotates_id(self):
|
||||
emitter = BaselineReasoningEmitter()
|
||||
# Capture the block id from the wire event rather than reaching
|
||||
# into emitter internals — the id on the emitted Start/Delta is
|
||||
# what the frontend actually receives.
|
||||
start_events = emitter.on_delta(_delta(reasoning="x"))
|
||||
first_id = start_events[0].id
|
||||
|
||||
events = emitter.close()
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], StreamReasoningEnd)
|
||||
assert events[0].id == first_id
|
||||
assert emitter.is_open is False
|
||||
# Next reasoning uses a fresh id.
|
||||
new_events = emitter.on_delta(_delta(reasoning="y"))
|
||||
assert isinstance(new_events[0], StreamReasoningStart)
|
||||
assert new_events[0].id != first_id
|
||||
|
||||
def test_close_is_idempotent(self):
|
||||
emitter = BaselineReasoningEmitter()
|
||||
assert emitter.close() == []
|
||||
emitter.on_delta(_delta(reasoning="x"))
|
||||
assert len(emitter.close()) == 1
|
||||
assert emitter.close() == []
|
||||
|
||||
def test_structured_details_round_trip(self):
|
||||
emitter = BaselineReasoningEmitter()
|
||||
events = emitter.on_delta(
|
||||
_delta(
|
||||
reasoning_details=[
|
||||
{"type": "reasoning.text", "text": "plan: "},
|
||||
{"type": "reasoning.summary", "summary": "do the thing"},
|
||||
]
|
||||
)
|
||||
)
|
||||
deltas = [e for e in events if isinstance(e, StreamReasoningDelta)]
|
||||
assert len(deltas) == 1
|
||||
assert deltas[0].delta == "plan: do the thing"
|
||||
|
||||
|
||||
class TestReasoningPersistence:
|
||||
"""The persistence contract: without ``role="reasoning"`` rows in
|
||||
session.messages, useHydrateOnStreamEnd overwrites the live-streamed
|
||||
reasoning parts and the Reasoning collapse vanishes. Every delta
|
||||
must be reflected in the persisted row the moment it's emitted."""
|
||||
|
||||
def test_session_row_appended_on_first_delta(self):
|
||||
session: list[ChatMessage] = []
|
||||
emitter = BaselineReasoningEmitter(session)
|
||||
|
||||
assert session == []
|
||||
emitter.on_delta(_delta(reasoning="hi"))
|
||||
assert len(session) == 1
|
||||
assert session[0].role == "reasoning"
|
||||
assert session[0].content == "hi"
|
||||
|
||||
def test_subsequent_deltas_mutate_same_row(self):
|
||||
session: list[ChatMessage] = []
|
||||
emitter = BaselineReasoningEmitter(session)
|
||||
|
||||
emitter.on_delta(_delta(reasoning="part one "))
|
||||
emitter.on_delta(_delta(reasoning="part two"))
|
||||
|
||||
assert len(session) == 1
|
||||
assert session[0].content == "part one part two"
|
||||
|
||||
def test_close_keeps_row_in_session(self):
|
||||
session: list[ChatMessage] = []
|
||||
emitter = BaselineReasoningEmitter(session)
|
||||
|
||||
emitter.on_delta(_delta(reasoning="thought"))
|
||||
emitter.close()
|
||||
|
||||
assert len(session) == 1
|
||||
assert session[0].content == "thought"
|
||||
|
||||
def test_second_reasoning_block_appends_new_row(self):
|
||||
session: list[ChatMessage] = []
|
||||
emitter = BaselineReasoningEmitter(session)
|
||||
|
||||
emitter.on_delta(_delta(reasoning="first"))
|
||||
emitter.close()
|
||||
emitter.on_delta(_delta(reasoning="second"))
|
||||
|
||||
assert len(session) == 2
|
||||
assert [m.content for m in session] == ["first", "second"]
|
||||
|
||||
def test_no_session_means_no_persistence(self):
|
||||
"""Emitter without attached session list emits wire events only."""
|
||||
emitter = BaselineReasoningEmitter()
|
||||
events = emitter.on_delta(_delta(reasoning="pure wire"))
|
||||
assert len(events) == 2 # start + delta, no crash
|
||||
# Nothing else to assert — just proves None session is supported.
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
"""Integration tests for baseline transcript flow.
|
||||
|
||||
Exercises the real helpers in ``baseline/service.py`` that download,
|
||||
validate, load, append to, backfill, and upload the transcript.
|
||||
Exercises the real helpers in ``baseline/service.py`` that restore,
|
||||
validate, load, append to, backfill, and upload the CLI session.
|
||||
Storage is mocked via ``download_transcript`` / ``upload_transcript``
|
||||
patches; no network access is required.
|
||||
"""
|
||||
@@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_append_gap_to_builder,
|
||||
_load_prior_transcript,
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
is_transcript_stale,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
@@ -54,106 +55,128 @@ def _make_transcript_content(*roles: str) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _make_session_messages(*roles: str) -> list[ChatMessage]:
|
||||
"""Build a list of ChatMessage objects matching the given roles."""
|
||||
return [
|
||||
ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles)
|
||||
]
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
|
||||
"""Model selection honours the per-request mode."""
|
||||
"""Baseline model resolution honours the per-request tier toggle."""
|
||||
|
||||
def test_fast_mode_selects_fast_model(self):
|
||||
assert _resolve_baseline_model("fast") == config.fast_model
|
||||
def test_advanced_tier_selects_advanced_model(self):
|
||||
assert _resolve_baseline_model("advanced") == config.advanced_model
|
||||
|
||||
def test_extended_thinking_selects_default_model(self):
|
||||
assert _resolve_baseline_model("extended_thinking") == config.model
|
||||
def test_standard_tier_selects_default_model(self):
|
||||
assert _resolve_baseline_model("standard") == config.model
|
||||
|
||||
def test_none_mode_selects_default_model(self):
|
||||
"""Critical: baseline users without a mode MUST keep the default (opus)."""
|
||||
def test_none_tier_selects_default_model(self):
|
||||
"""Baseline users without a tier MUST keep the default (standard)."""
|
||||
assert _resolve_baseline_model(None) == config.model
|
||||
|
||||
def test_default_and_fast_models_same(self):
|
||||
"""SDK defaults currently keep standard and fast on Sonnet 4.6."""
|
||||
assert config.model == config.fast_model
|
||||
def test_standard_and_advanced_models_differ(self):
|
||||
"""Advanced tier defaults to a different (Opus) model than standard."""
|
||||
assert config.model != config.advanced_model
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
|
||||
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
|
||||
"""``_load_prior_transcript`` wraps the CLI session restore + validate + load flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loads_fresh_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=content.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
assert dl.message_count == 2
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_stale_transcript(self):
|
||||
"""msg_count strictly less than session-1 is treated as stale."""
|
||||
async def test_fills_gap_when_transcript_is_behind(self):
|
||||
"""When transcript covers fewer messages than session, gap is filled from DB."""
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
# session has 6 messages, transcript only covers 2 → stale.
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
|
||||
restore = TranscriptDownload(
|
||||
content=content.encode("utf-8"), message_count=2, mode="baseline"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=6,
|
||||
session_messages=_make_session_messages(
|
||||
"user", "assistant", "user", "assistant", "user"
|
||||
),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
|
||||
assert builder.entry_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_transcript_returns_false(self):
|
||||
async def test_missing_transcript_allows_upload(self):
|
||||
"""Nothing in GCS → upload is safe; the turn writes the first snapshot."""
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_transcript_returns_false(self):
|
||||
async def test_invalid_transcript_allows_upload(self):
|
||||
"""Corrupt file in GCS → overwriting with a valid one is better."""
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content='{"type":"progress","uuid":"a"}\n',
|
||||
restore = TranscriptDownload(
|
||||
content=b'{"type":"progress","uuid":"a"}\n',
|
||||
message_count=1,
|
||||
mode="sdk",
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -163,36 +186,39 @@ class TestLoadPriorTranscript:
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("boom")),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_message_count_not_stale(self):
|
||||
"""When msg_count is 0 (unknown), staleness check is skipped."""
|
||||
"""When msg_count is 0 (unknown), gap detection is skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
restore = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant").encode("utf-8"),
|
||||
message_count=0,
|
||||
mode="sdk",
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=20,
|
||||
session_messages=_make_session_messages(*["user"] * 20),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
assert builder.entry_count == 2
|
||||
|
||||
|
||||
@@ -227,7 +253,7 @@ class TestUploadFinalTranscript:
|
||||
assert call_kwargs["user_id"] == "user-1"
|
||||
assert call_kwargs["session_id"] == "session-1"
|
||||
assert call_kwargs["message_count"] == 2
|
||||
assert "hello" in call_kwargs["content"]
|
||||
assert b"hello" in call_kwargs["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_upload_when_builder_empty(self):
|
||||
@@ -374,17 +400,19 @@ class TestRoundTrip:
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip(self):
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=prior.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, _ = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
@@ -424,11 +452,11 @@ class TestRoundTrip:
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "new question" in uploaded
|
||||
assert "new answer" in uploaded
|
||||
assert b"new question" in uploaded
|
||||
assert b"new answer" in uploaded
|
||||
# Original content preserved in the round trip.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
assert b"user message 0" in uploaded
|
||||
assert b"assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_append_guard(self):
|
||||
@@ -459,36 +487,6 @@ class TestRoundTrip:
|
||||
assert builder.entry_count == initial_count
|
||||
|
||||
|
||||
class TestIsTranscriptStale:
|
||||
"""``is_transcript_stale`` gates prior-transcript loading."""
|
||||
|
||||
def test_none_download_is_not_stale(self):
|
||||
assert is_transcript_stale(None, session_msg_count=5) is False
|
||||
|
||||
def test_zero_message_count_is_not_stale(self):
|
||||
"""Legacy transcripts without msg_count tracking must remain usable."""
|
||||
dl = TranscriptDownload(content="", message_count=0)
|
||||
assert is_transcript_stale(dl, session_msg_count=20) is False
|
||||
|
||||
def test_stale_when_covers_less_than_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=2)
|
||||
# session has 6 messages; transcript must cover at least 5 (6-1).
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is True
|
||||
|
||||
def test_fresh_when_covers_full_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=5)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_fresh_when_exceeds_prefix(self):
|
||||
"""Race: transcript ahead of session count is still acceptable."""
|
||||
dl = TranscriptDownload(content="", message_count=10)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_boundary_equal_to_prefix_minus_one(self):
|
||||
dl = TranscriptDownload(content="", message_count=5)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
|
||||
class TestShouldUploadTranscript:
|
||||
"""``should_upload_transcript`` gates the final upload."""
|
||||
|
||||
@@ -510,7 +508,7 @@ class TestShouldUploadTranscript:
|
||||
|
||||
|
||||
class TestTranscriptLifecycle:
|
||||
"""End-to-end: download → validate → build → upload.
|
||||
"""End-to-end: restore → validate → build → upload.
|
||||
|
||||
Simulates the full transcript lifecycle inside
|
||||
``stream_chat_completion_baseline`` by mocking the storage layer and
|
||||
@@ -519,27 +517,29 @@ class TestTranscriptLifecycle:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_happy_path(self):
|
||||
"""Fresh download, append a turn, upload covers the session."""
|
||||
"""Fresh restore, append a turn, upload covers the session."""
|
||||
builder = TranscriptBuilder()
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=prior.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
# --- 1. Download & load prior transcript ---
|
||||
covers = await _load_prior_transcript(
|
||||
# --- 1. Restore & load prior session ---
|
||||
covers, _ = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
@@ -559,10 +559,7 @@ class TestTranscriptLifecycle:
|
||||
|
||||
# --- 3. Gate + upload ---
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is True
|
||||
should_upload_transcript(user_id="user-1", upload_safe=covers) is True
|
||||
)
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
@@ -574,20 +571,21 @@ class TestTranscriptLifecycle:
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "follow-up question" in uploaded
|
||||
assert "follow-up answer" in uploaded
|
||||
assert b"follow-up question" in uploaded
|
||||
assert b"follow-up answer" in uploaded
|
||||
# Original prior-turn content preserved.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
assert b"user message 0" in uploaded
|
||||
assert b"assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_stale_download_suppresses_upload(self):
|
||||
"""Stale download → covers=False → upload must be skipped."""
|
||||
async def test_lifecycle_stale_download_fills_gap(self):
|
||||
"""When transcript covers fewer messages, gap is filled rather than rejected."""
|
||||
builder = TranscriptBuilder()
|
||||
# session has 10 msgs but stored transcript only covers 2 → stale.
|
||||
# session has 5 msgs but stored transcript only covers 2 → gap filled.
|
||||
stale = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
content=_make_transcript_content("user", "assistant").encode("utf-8"),
|
||||
message_count=2,
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
@@ -601,20 +599,18 @@ class TestTranscriptLifecycle:
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, _ = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
session_messages=_make_session_messages(
|
||||
"user", "assistant", "user", "assistant", "user"
|
||||
),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
# The caller's gate mirrors the production path.
|
||||
assert (
|
||||
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
assert covers is True
|
||||
# Gap was filled: 2 from transcript + 2 gap messages
|
||||
assert builder.entry_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_anonymous_user_skips_upload(self):
|
||||
@@ -627,15 +623,11 @@ class TestTranscriptLifecycle:
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert (
|
||||
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
|
||||
is False
|
||||
)
|
||||
assert should_upload_transcript(user_id=None, upload_safe=True) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_missing_download_still_uploads_new_content(self):
|
||||
"""No prior transcript → covers defaults to True in the service,
|
||||
new turn should upload cleanly."""
|
||||
"""No prior session → upload is safe; the turn writes the first snapshot."""
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
@@ -648,20 +640,117 @@ class TestTranscriptLifecycle:
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=1,
|
||||
session_messages=_make_session_messages("user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
# No download: covers is False, so the production path would
|
||||
# skip upload. This protects against overwriting a future
|
||||
# more-complete transcript with a single-turn snapshot.
|
||||
assert covers is False
|
||||
# Nothing in GCS → upload is safe so the first baseline turn
|
||||
# can write the initial transcript snapshot.
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is False
|
||||
should_upload_transcript(user_id="user-1", upload_safe=upload_safe)
|
||||
is True
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _append_gap_to_builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppendGapToBuilder:
|
||||
"""``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries."""
|
||||
|
||||
def test_user_message_appended(self):
|
||||
builder = TranscriptBuilder()
|
||||
msgs = [ChatMessage(role="user", content="hello")]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 1
|
||||
assert builder.last_entry_type == "user"
|
||||
|
||||
def test_assistant_text_message_appended(self):
|
||||
builder = TranscriptBuilder()
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="q"),
|
||||
ChatMessage(role="assistant", content="answer"),
|
||||
]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
assert "answer" in builder.to_jsonl()
|
||||
|
||||
def test_assistant_with_tool_calls_appended(self):
|
||||
"""Assistant tool_calls are recorded as tool_use blocks in the transcript."""
|
||||
builder = TranscriptBuilder()
|
||||
tool_call = {
|
||||
"id": "tc-1",
|
||||
"type": "function",
|
||||
"function": {"name": "my_tool", "arguments": '{"key":"val"}'},
|
||||
}
|
||||
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 1
|
||||
jsonl = builder.to_jsonl()
|
||||
assert "tool_use" in jsonl
|
||||
assert "my_tool" in jsonl
|
||||
assert "tc-1" in jsonl
|
||||
|
||||
def test_assistant_invalid_json_args_uses_empty_dict(self):
|
||||
"""Malformed JSON in tool_call arguments falls back to {}."""
|
||||
builder = TranscriptBuilder()
|
||||
tool_call = {
|
||||
"id": "tc-bad",
|
||||
"type": "function",
|
||||
"function": {"name": "bad_tool", "arguments": "not-json"},
|
||||
}
|
||||
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 1
|
||||
jsonl = builder.to_jsonl()
|
||||
assert '"input":{}' in jsonl
|
||||
|
||||
def test_assistant_empty_content_and_no_tools_uses_fallback(self):
|
||||
"""Assistant with no content and no tool_calls gets a fallback empty text block."""
|
||||
builder = TranscriptBuilder()
|
||||
msgs = [ChatMessage(role="assistant", content=None)]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 1
|
||||
jsonl = builder.to_jsonl()
|
||||
assert "text" in jsonl
|
||||
|
||||
def test_tool_role_with_tool_call_id_appended(self):
|
||||
"""Tool result messages are appended when tool_call_id is set."""
|
||||
builder = TranscriptBuilder()
|
||||
# Need a preceding assistant tool_use entry
|
||||
builder.append_user("use tool")
|
||||
builder.append_assistant(
|
||||
content_blocks=[
|
||||
{"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}}
|
||||
]
|
||||
)
|
||||
msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 3
|
||||
assert "tool_result" in builder.to_jsonl()
|
||||
|
||||
def test_tool_role_without_tool_call_id_skipped(self):
|
||||
"""Tool messages without tool_call_id are silently skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 0
|
||||
|
||||
def test_tool_call_missing_function_key_uses_unknown_name(self):
|
||||
"""A tool_call dict with no 'function' key uses 'unknown' as the tool name."""
|
||||
builder = TranscriptBuilder()
|
||||
# Tool call dict exists but 'function' sub-dict is missing entirely
|
||||
msgs = [
|
||||
ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}])
|
||||
]
|
||||
_append_gap_to_builder(msgs, builder)
|
||||
assert builder.entry_count == 1
|
||||
jsonl = builder.to_jsonl()
|
||||
assert "unknown" in jsonl
|
||||
|
||||
@@ -17,8 +17,8 @@ from backend.util.clients import OPENROUTER_BASE_URL
|
||||
CopilotMode = Literal["fast", "extended_thinking"]
|
||||
|
||||
# Per-request model tier set by the frontend model toggle.
|
||||
# 'standard' uses the global config default (currently Sonnet).
|
||||
# 'advanced' forces the highest-capability model (currently Opus).
|
||||
# 'standard' uses ``ChatConfig.model`` (Sonnet by default).
|
||||
# 'advanced' uses ``ChatConfig.advanced_model`` (Opus by default).
|
||||
# None means no preference — falls through to LD per-user targeting, then config.
|
||||
# Using tier names instead of model names keeps the contract model-agnostic.
|
||||
CopilotLlmModel = Literal["standard", "advanced"]
|
||||
@@ -27,16 +27,21 @@ CopilotLlmModel = Literal["standard", "advanced"]
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# OpenAI API Configuration
|
||||
# Chat model tiers — applied orthogonally to the path (fast=baseline vs
|
||||
# extended_thinking=SDK). The "fast" vs "extended_thinking" toggle picks
|
||||
# which code path runs (no reasoning / heavy SDK); "standard" vs
|
||||
# "advanced" picks the model inside that path.
|
||||
model: str = Field(
|
||||
default="anthropic/claude-sonnet-4-6",
|
||||
description="Default model for extended thinking mode. "
|
||||
"Uses Sonnet 4.6 as the balanced default. "
|
||||
"Override via CHAT_MODEL env var if you want a different default.",
|
||||
description="Model used for the 'standard' tier (Sonnet by default). "
|
||||
"Applies to both baseline (fast) and SDK (extended thinking) paths. "
|
||||
"Override via CHAT_MODEL env var.",
|
||||
)
|
||||
fast_model: str = Field(
|
||||
default="anthropic/claude-sonnet-4-6",
|
||||
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
|
||||
advanced_model: str = Field(
|
||||
default="anthropic/claude-opus-4-7",
|
||||
description="Model used for the 'advanced' tier (Opus by default). "
|
||||
"Applies to both baseline (fast) and SDK (extended thinking) paths. "
|
||||
"Override via CHAT_ADVANCED_MODEL env var.",
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
@@ -96,25 +101,31 @@ class ChatConfig(BaseSettings):
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Rate limiting — token-based limits per day and per week.
|
||||
# Per-turn token cost varies with context size: ~10-15K for early turns,
|
||||
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
|
||||
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
|
||||
# allows ~70-100 turns/day.
|
||||
# Rate limiting — cost-based limits per day and per week, stored in
|
||||
# microdollars (1 USD = 1_000_000). The counter tracks the real
|
||||
# generation cost reported by the provider (OpenRouter ``usage.cost``
|
||||
# or Claude Agent SDK ``total_cost_usd``), so cache discounts and
|
||||
# cross-model price differences are already reflected — no token
|
||||
# weighting or model multiplier is applied on top.
|
||||
# Checked at the HTTP layer (routes.py) before each turn.
|
||||
#
|
||||
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
|
||||
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
|
||||
# ENTERPRISE) multiply these by their tier multiplier (see
|
||||
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
|
||||
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
|
||||
# User.subscriptionTier DB column and resolved inside
|
||||
# get_global_rate_limits().
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
#
|
||||
# These defaults act as the ceiling when LaunchDarkly is unreachable;
|
||||
# the live per-tier values come from the COPILOT_*_COST_LIMIT flags.
|
||||
daily_cost_limit_microdollars: int = Field(
|
||||
default=1_000_000,
|
||||
description="Max cost per day in microdollars, resets at midnight UTC "
|
||||
"(0 = unlimited).",
|
||||
)
|
||||
weekly_token_limit: int = Field(
|
||||
default=12_500_000,
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
weekly_cost_limit_microdollars: int = Field(
|
||||
default=5_000_000,
|
||||
description="Max cost per week in microdollars, resets Monday 00:00 UTC "
|
||||
"(0 = unlimited).",
|
||||
)
|
||||
|
||||
# Cost (in credits / cents) to reset the daily rate limit using credits.
|
||||
@@ -183,9 +194,11 @@ class ChatConfig(BaseSettings):
|
||||
default=8192,
|
||||
ge=1024,
|
||||
le=128000,
|
||||
description="Maximum thinking/reasoning tokens per LLM call. "
|
||||
"Extended thinking on Opus can generate 50k+ tokens at $75/M — "
|
||||
"capping this is the single biggest cost lever. "
|
||||
description="Maximum thinking/reasoning tokens per LLM call. Applies "
|
||||
"to both the Claude Agent SDK path (as ``max_thinking_tokens``) and "
|
||||
"the baseline OpenRouter path (as ``extra_body.reasoning.max_tokens`` "
|
||||
"on Anthropic routes). Extended thinking on Opus can generate 50k+ "
|
||||
"tokens at $75/M — capping this is the single biggest cost lever. "
|
||||
"8192 is sufficient for most tasks; increase for complex reasoning.",
|
||||
)
|
||||
claude_agent_thinking_effort: Literal["low", "medium", "high", "max"] | None = (
|
||||
@@ -214,6 +227,18 @@ class ChatConfig(BaseSettings):
|
||||
"from the prefix. Set to False to fall back to passing the system "
|
||||
"prompt as a raw string.",
|
||||
)
|
||||
baseline_prompt_cache_ttl: str = Field(
|
||||
default="1h",
|
||||
description="TTL for the ephemeral prompt-cache markers on the baseline "
|
||||
"OpenRouter path. Anthropic supports only `5m` (default, 1.25x input "
|
||||
"price for the write) or `1h` (2x input price for the write). 1h is "
|
||||
"strictly cheaper overall when the static prefix gets >7 reads per "
|
||||
"write-window; since the system prompt + tools array is identical "
|
||||
"across all users in our workspace, 1h is the default so cross-user "
|
||||
"reads amortise the higher write cost. Anthropic has no longer "
|
||||
"(24h, permanent) TTL option — see "
|
||||
"https://platform.claude.com/docs/en/build-with-claude/prompt-caching.",
|
||||
)
|
||||
claude_agent_cli_path: str | None = Field(
|
||||
default=None,
|
||||
description="Optional explicit path to a Claude Code CLI binary. "
|
||||
|
||||
@@ -9,6 +9,11 @@ COPILOT_RETRYABLE_ERROR_PREFIX = (
|
||||
)
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Canonical marker appended as an assistant ChatMessage when the SDK stream
|
||||
# ends without a ResultMessage (user hit Stop). Checked by exact equality
|
||||
# at turn start so the next turn's --resume transcript doesn't carry it.
|
||||
STOPPED_BY_USER_MARKER = f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user"
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
# Used to distinguish CoPilot-generated records from real graph execution records
|
||||
# in PendingHumanReview and other tables.
|
||||
@@ -27,6 +32,24 @@ COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context li
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool / stream timing budget
|
||||
# ---------------------------------------------------------------------------
|
||||
# Max seconds any single MCP tool call may block the stream before returning
|
||||
# a "still running" handle. Shared by run_agent (wait_for_result),
|
||||
# view_agent_output (wait_if_running), run_sub_session (wait_for_result),
|
||||
# get_sub_session_result (wait_if_running), and run_block (hard cap).
|
||||
#
|
||||
# Chosen so the stream idle timeout (2× this) always has headroom — a tool
|
||||
# that returns right at the cap can't race the idle watchdog.
|
||||
MAX_TOOL_WAIT_SECONDS = 5 * 60 # 5 minutes
|
||||
|
||||
# Idle-stream watchdog: abort the SDK stream if no meaningful event arrives
|
||||
# for this long. Derived from MAX_TOOL_WAIT_SECONDS so the invariant
|
||||
# "no tool blocks >= idle_timeout" holds by construction.
|
||||
STREAM_IDLE_TIMEOUT_SECONDS = MAX_TOOL_WAIT_SECONDS * 2 # 10 minutes
|
||||
|
||||
|
||||
def is_copilot_synthetic_id(id_value: str) -> bool:
|
||||
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
|
||||
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
|
||||
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
# projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
|
||||
@@ -10,9 +10,11 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatMessageWhereInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
FindManyChatMessageArgsFromChatSession,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -30,6 +32,8 @@ from .model import get_chat_session as get_chat_session_cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BOUNDARY_SCAN_LIMIT = 10
|
||||
|
||||
|
||||
class PaginatedMessages(BaseModel):
|
||||
"""Result of a paginated message query."""
|
||||
@@ -69,12 +73,10 @@ async def get_chat_messages_paginated(
|
||||
in parallel with the message query. Returns ``None`` when the session
|
||||
is not found or does not belong to the user.
|
||||
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
limit: Max messages to return.
|
||||
before_sequence: Cursor — return messages with sequence < this value.
|
||||
user_id: If provided, filters via ``Session.userId`` so only the
|
||||
session owner's messages are returned (acts as an ownership guard).
|
||||
After fetching, a visibility guarantee ensures the page contains at least
|
||||
one user or assistant message. If the entire page is tool messages (which
|
||||
are hidden in the UI), it expands backward until a visible message is found
|
||||
so the chat never appears blank.
|
||||
"""
|
||||
# Build session-existence / ownership check
|
||||
session_where: ChatSessionWhereInput = {"id": session_id}
|
||||
@@ -82,7 +84,7 @@ async def get_chat_messages_paginated(
|
||||
session_where["userId"] = user_id
|
||||
|
||||
# Build message include — fetch paginated messages in the same query
|
||||
msg_include: dict[str, Any] = {
|
||||
msg_include: FindManyChatMessageArgsFromChatSession = {
|
||||
"order_by": {"sequence": "desc"},
|
||||
"take": limit + 1,
|
||||
}
|
||||
@@ -111,42 +113,18 @@ async def get_chat_messages_paginated(
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
_BOUNDARY_SCAN_LIMIT = 10
|
||||
if results and results[0].role == "tool":
|
||||
boundary_where: dict[str, Any] = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
results, has_more = await _expand_tool_boundary(
|
||||
session_id, results, has_more, user_id
|
||||
)
|
||||
|
||||
# Visibility guarantee: if the entire page has no user/assistant messages
|
||||
# (all tool messages), the chat would appear blank. Expand backward
|
||||
# until we find at least one visible message.
|
||||
if results and not any(m.role in ("user", "assistant") for m in results):
|
||||
results, has_more = await _expand_for_visibility(
|
||||
session_id, results, has_more, user_id
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
# Only mark has_more if the expanded boundary isn't the
|
||||
# very start of the conversation (sequence 0).
|
||||
if boundary_msgs[0].sequence > 0:
|
||||
has_more = True
|
||||
|
||||
messages = [ChatMessage.from_db(m) for m in results]
|
||||
oldest_sequence = messages[0].sequence if messages else None
|
||||
@@ -159,6 +137,98 @@ async def get_chat_messages_paginated(
|
||||
)
|
||||
|
||||
|
||||
async def _expand_tool_boundary(
|
||||
session_id: str,
|
||||
results: list[Any],
|
||||
has_more: bool,
|
||||
user_id: str | None,
|
||||
) -> tuple[list[Any], bool]:
|
||||
"""Expand backward from the oldest message to include the owning assistant
|
||||
message when the page starts mid-tool-group."""
|
||||
boundary_where: ChatMessageWhereInput = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
has_more = boundary_msgs[0].sequence > 0
|
||||
return results, has_more
|
||||
|
||||
|
||||
_VISIBILITY_EXPAND_LIMIT = 200
|
||||
|
||||
|
||||
async def _expand_for_visibility(
|
||||
session_id: str,
|
||||
results: list[Any],
|
||||
has_more: bool,
|
||||
user_id: str | None,
|
||||
) -> tuple[list[Any], bool]:
|
||||
"""Expand backward until the page contains at least one user or assistant
|
||||
message, so the chat is never blank."""
|
||||
expand_where: ChatMessageWhereInput = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
expand_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=expand_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_VISIBILITY_EXPAND_LIMIT,
|
||||
)
|
||||
if not extra:
|
||||
return results, has_more
|
||||
|
||||
# Collect messages until we find a visible one (user/assistant)
|
||||
prepend = []
|
||||
found_visible = False
|
||||
for msg in extra:
|
||||
prepend.append(msg)
|
||||
if msg.role in ("user", "assistant"):
|
||||
found_visible = True
|
||||
break
|
||||
|
||||
if not found_visible:
|
||||
logger.warning(
|
||||
"Visibility expansion did not find any user/assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
|
||||
prepend.reverse()
|
||||
if prepend:
|
||||
results = prepend + results
|
||||
has_more = prepend[0].sequence > 0
|
||||
return results, has_more
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
|
||||
@@ -175,6 +175,138 @@ async def test_no_where_on_messages_without_before_sequence(
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
# ---------- Visibility guarantee ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_expands_when_all_tool_messages(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When the entire page is tool messages, expand backward to find
|
||||
at least one visible (user/assistant) message so the chat isn't blank."""
|
||||
find_first, find_many = mock_db
|
||||
# Newest 3 messages are all tool messages (DESC → reversed to ASC)
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(12, role="tool"),
|
||||
_make_msg(11, role="tool"),
|
||||
_make_msg(10, role="tool"),
|
||||
],
|
||||
)
|
||||
# Boundary expansion finds the owning assistant first (boundary fix),
|
||||
# then visibility expansion finds a user message further back
|
||||
find_many.side_effect = [
|
||||
# First call: boundary fix (oldest msg is tool → find owner)
|
||||
[_make_msg(9, role="tool"), _make_msg(8, role="tool")],
|
||||
# Second call: visibility expansion (still all tool → find visible)
|
||||
[_make_msg(7, role="tool"), _make_msg(6, role="assistant")],
|
||||
]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
|
||||
|
||||
assert page is not None
|
||||
# Should include the expanded messages + original tool messages
|
||||
roles = [m.role for m in page.messages]
|
||||
assert "assistant" in roles or "user" in roles
|
||||
assert page.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_visibility_expansion_when_visible_messages_present(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""No visibility expansion needed when page already has visible messages."""
|
||||
find_first, find_many = mock_db
|
||||
# Page has an assistant message among tool messages
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(5, role="tool"),
|
||||
_make_msg(4, role="assistant"),
|
||||
_make_msg(3, role="user"),
|
||||
],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
|
||||
|
||||
assert page is not None
|
||||
# Boundary expansion might fire (oldest is tool), but NOT visibility
|
||||
assert [m.sequence for m in page.messages][0] <= 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_no_expansion_when_no_earlier_messages(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When the page is all tool messages but there are no earlier messages
|
||||
in the DB, visibility expansion returns early without changes."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(1, role="tool"), _make_msg(0, role="tool")],
|
||||
)
|
||||
# Boundary expansion: no earlier messages
|
||||
# Visibility expansion: no earlier messages
|
||||
find_many.side_effect = [[], []]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
|
||||
|
||||
assert page is not None
|
||||
assert all(m.role == "tool" for m in page.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_expansion_reaches_seq_zero(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When visibility expansion finds a visible message at sequence 0,
|
||||
has_more should be False."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
|
||||
)
|
||||
find_many.side_effect = [
|
||||
# Boundary expansion
|
||||
[_make_msg(3, role="tool")],
|
||||
# Visibility expansion — finds user at seq 0
|
||||
[
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(1, role="tool"),
|
||||
_make_msg(0, role="user"),
|
||||
],
|
||||
]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
|
||||
|
||||
assert page is not None
|
||||
assert page.messages[0].role == "user"
|
||||
assert page.messages[0].sequence == 0
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_expansion_with_user_id(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Visibility expansion passes user_id filter to the boundary query."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(10, role="tool")],
|
||||
)
|
||||
find_many.side_effect = [
|
||||
# Boundary expansion
|
||||
[_make_msg(9, role="tool")],
|
||||
# Visibility expansion
|
||||
[_make_msg(8, role="assistant")],
|
||||
]
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=1, user_id="user-abc")
|
||||
|
||||
# Both find_many calls should include the user_id session filter
|
||||
for call in find_many.call_args_list:
|
||||
where = call.kwargs.get("where") or call[1].get("where")
|
||||
assert "Session" in where
|
||||
assert where["Session"] == {"is": {"userId": "user-abc"}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_id_filter_applied_to_session_where(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
@@ -329,7 +461,8 @@ async def test_boundary_expansion_warns_when_no_owner_found(
|
||||
|
||||
with patch("backend.copilot.db.logger") as mock_logger:
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
mock_logger.warning.assert_called_once()
|
||||
# Two warnings: boundary expansion + visibility expansion (all tool msgs)
|
||||
assert mock_logger.warning.call_count == 2
|
||||
|
||||
assert page is not None
|
||||
assert page.messages[0].role == "tool"
|
||||
|
||||
@@ -34,6 +34,7 @@ from .utils import (
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
create_copilot_queue_config,
|
||||
get_session_lock_key,
|
||||
)
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
@@ -366,7 +367,7 @@ class CoPilotExecutor(AppProcess):
|
||||
# Try to acquire cluster-wide lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"copilot:session:{session_id}:lock",
|
||||
key=get_session_lock_key(session_id),
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
|
||||
@@ -222,6 +222,10 @@ class CoPilotProcessor:
|
||||
Shuts down the workspace storage instance that belongs to this
|
||||
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
|
||||
runs on the same loop that created the session.
|
||||
|
||||
Sub-AutoPilots are enqueued on the copilot_execution queue, so
|
||||
rolling deploys survive via RabbitMQ redelivery — no bespoke
|
||||
shutdown notifier needed.
|
||||
"""
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
@@ -342,7 +346,9 @@ class CoPilotProcessor:
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
# stream_and_publish wraps the raw stream with registry
|
||||
# publishing (shared with collect_copilot_response).
|
||||
# publishing so subscribers on the session Redis stream
|
||||
# (e.g. wait_for_session_result, SSE clients) receive the
|
||||
# same events as they are produced.
|
||||
raw_stream = stream_fn(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
@@ -352,27 +358,37 @@ class CoPilotProcessor:
|
||||
file_ids=entry.file_ids,
|
||||
mode=effective_mode,
|
||||
model=entry.model,
|
||||
permissions=entry.permissions,
|
||||
request_arrival_at=entry.request_arrival_at,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
published_stream = stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
turn_id=entry.turn_id,
|
||||
stream=raw_stream,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
break
|
||||
)
|
||||
# Explicit aclose() on early exit: ``async for … break`` does
|
||||
# not close the generator, so GeneratorExit would never reach
|
||||
# stream_chat_completion_sdk, leaving its stream lock held
|
||||
# until GC eventually runs.
|
||||
try:
|
||||
async for chunk in published_stream:
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
break
|
||||
|
||||
# Capture StreamError so mark_session_completed receives
|
||||
# the error message (stream_and_publish yields but does
|
||||
# not publish StreamError — that's done by mark_session_completed).
|
||||
if isinstance(chunk, StreamError):
|
||||
error_msg = chunk.errorText
|
||||
break
|
||||
# Capture StreamError so mark_session_completed receives
|
||||
# the error message (stream_and_publish yields but does
|
||||
# not publish StreamError — that's done by mark_session_completed).
|
||||
if isinstance(chunk, StreamError):
|
||||
error_msg = chunk.errorText
|
||||
break
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
finally:
|
||||
await published_stream.aclose()
|
||||
|
||||
# Stream loop completed
|
||||
if cancel.is_set():
|
||||
|
||||
@@ -10,14 +10,18 @@ the real production helpers from ``processor.py`` so the routing logic
|
||||
has meaningful coverage.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import logging
|
||||
import threading
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.executor.processor import (
|
||||
CoPilotProcessor,
|
||||
resolve_effective_mode,
|
||||
resolve_use_sdk_for_mode,
|
||||
)
|
||||
from backend.copilot.executor.utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
|
||||
|
||||
class TestResolveUseSdkForMode:
|
||||
@@ -173,3 +177,101 @@ class TestResolveEffectiveMode:
|
||||
) as flag_mock:
|
||||
assert await resolve_effective_mode("fast", None) is None
|
||||
flag_mock.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _execute_async aclose propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _TrackedStream:
|
||||
"""Minimal async-generator stand-in that records whether ``aclose``
|
||||
was called, so tests can verify the processor forces explicit cleanup
|
||||
of the published stream on every exit path (normal + break on cancel)."""
|
||||
|
||||
def __init__(self, events: list):
|
||||
self._events = events
|
||||
self.aclose_called = False
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._events:
|
||||
raise StopAsyncIteration
|
||||
return self._events.pop(0)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.aclose_called = True
|
||||
|
||||
|
||||
def _make_entry() -> CoPilotExecutionEntry:
|
||||
return CoPilotExecutionEntry(
|
||||
session_id="sess-1",
|
||||
turn_id="turn-1",
|
||||
user_id="user-1",
|
||||
message="hi",
|
||||
is_user_message=True,
|
||||
request_arrival_at=0.0,
|
||||
)
|
||||
|
||||
|
||||
def _make_log() -> CoPilotLogMetadata:
|
||||
return CoPilotLogMetadata(logger=logging.getLogger("test-copilot"))
|
||||
|
||||
|
||||
class TestExecuteAsyncAclose:
|
||||
"""``_execute_async`` must call ``aclose`` on the published stream both
|
||||
when the loop exits naturally and when ``cancel`` is set mid-stream —
|
||||
otherwise ``stream_chat_completion_sdk`` stays suspended and keeps
|
||||
holding the per-session Redis lock until GC."""
|
||||
|
||||
def _patches(self, published_stream: _TrackedStream):
|
||||
"""Shared mock context: patches every dependency ``_execute_async``
|
||||
touches so the aclose path is the only behaviour under test."""
|
||||
return [
|
||||
patch(
|
||||
"backend.copilot.executor.processor.ChatConfig",
|
||||
return_value=MagicMock(test_mode=True, use_claude_agent_sdk=True),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.executor.processor.stream_chat_completion_dummy",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.executor.processor.stream_registry.stream_and_publish",
|
||||
return_value=published_stream,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_exit_calls_aclose(self) -> None:
|
||||
published = _TrackedStream(events=[MagicMock(), MagicMock()])
|
||||
proc = CoPilotProcessor()
|
||||
cancel = threading.Event()
|
||||
cluster_lock = MagicMock()
|
||||
|
||||
patches = self._patches(published)
|
||||
with patches[0], patches[1], patches[2], patches[3]:
|
||||
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
|
||||
|
||||
assert published.aclose_called is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_break_calls_aclose(self) -> None:
|
||||
events = [MagicMock()] # first chunk delivered, then cancel fires
|
||||
published = _TrackedStream(events=events)
|
||||
proc = CoPilotProcessor()
|
||||
cancel = threading.Event()
|
||||
cancel.set() # pre-set so the loop breaks on the first chunk
|
||||
cluster_lock = MagicMock()
|
||||
|
||||
patches = self._patches(published)
|
||||
with patches[0], patches[1], patches[2], patches[3]:
|
||||
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
|
||||
|
||||
assert published.aclose_called is True
|
||||
|
||||
@@ -10,6 +10,7 @@ import logging
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
@@ -81,6 +82,12 @@ COPILOT_CANCEL_EXCHANGE = Exchange(
|
||||
)
|
||||
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
|
||||
|
||||
|
||||
def get_session_lock_key(session_id: str) -> str:
|
||||
"""Redis key for the per-session cluster lock held by the executing pod."""
|
||||
return f"copilot:session:{session_id}:lock"
|
||||
|
||||
|
||||
# CoPilot operations can include extended thinking and agent generation
|
||||
# which may take 30+ minutes to complete
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
|
||||
@@ -163,6 +170,20 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
model: CopilotLlmModel | None = None
|
||||
"""Per-request model tier: 'standard' or 'advanced'. None = server default."""
|
||||
|
||||
permissions: CopilotPermissions | None = None
|
||||
"""Capability filter inherited from a parent run (e.g. ``run_sub_session``
|
||||
forwards its parent's permissions so the sub can't escalate). ``None``
|
||||
means the worker applies no filter."""
|
||||
|
||||
request_arrival_at: float = 0.0
|
||||
"""Unix-epoch seconds (server clock) when the originating HTTP
|
||||
``/stream`` request arrived. The executor's turn-start drain uses
|
||||
this to decide whether each pending message was typed BEFORE or AFTER
|
||||
the turn's ``current`` message, and orders the combined user bubble
|
||||
chronologically. Defaults to ``0.0`` for backward compatibility with
|
||||
queue messages written before this field existed (they sort as "all
|
||||
pending before current" — the pre-fix behaviour)."""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -184,6 +205,8 @@ async def enqueue_copilot_turn(
|
||||
file_ids: list[str] | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
model: CopilotLlmModel | None = None,
|
||||
permissions: CopilotPermissions | None = None,
|
||||
request_arrival_at: float = 0.0,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -197,6 +220,8 @@ async def enqueue_copilot_turn(
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
|
||||
model: Per-request model tier ('standard' or 'advanced'). None = server default.
|
||||
permissions: Capability filter inherited from a parent run (sub-AutoPilot).
|
||||
None = no filter.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -210,6 +235,8 @@ async def enqueue_copilot_turn(
|
||||
file_ids=file_ids,
|
||||
mode=mode,
|
||||
model=model,
|
||||
permissions=permissions,
|
||||
request_arrival_at=request_arrival_at,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Per-request idempotency lock for the /stream endpoint.
|
||||
|
||||
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
|
||||
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
|
||||
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
|
||||
``None`` when the key already exists (duplicate request).
|
||||
2. ``release()`` — deletes the key. Must be called on turn completion or turn
|
||||
error so the next legitimate send is never blocked.
|
||||
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
|
||||
the backend turn is still running, and releasing would reopen the duplicate
|
||||
window for infra-level retries. The 30 s TTL is the safety net.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_KEY_PREFIX = "chat:msg_dedup"
|
||||
_TTL_SECONDS = 30
|
||||
|
||||
|
||||
class _DedupLock:
|
||||
def __init__(self, key: str, redis) -> None:
|
||||
self._key = key
|
||||
self._redis = redis
|
||||
|
||||
async def release(self) -> None:
|
||||
"""Best-effort key deletion. The TTL handles failures silently."""
|
||||
try:
|
||||
await self._redis.delete(self._key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def acquire_dedup_lock(
|
||||
session_id: str,
|
||||
message: str | None,
|
||||
file_ids: list[str] | None,
|
||||
) -> _DedupLock | None:
|
||||
"""Acquire the idempotency lock for this (session, message, files) tuple.
|
||||
|
||||
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
|
||||
Returns ``None`` when a duplicate is detected (lock already held).
|
||||
Returns ``None`` when there is nothing to deduplicate (no message, no files).
|
||||
"""
|
||||
if not message and not file_ids:
|
||||
return None
|
||||
|
||||
sorted_ids = ":".join(sorted(file_ids or []))
|
||||
content_hash = hashlib.sha256(
|
||||
f"{session_id}:{message or ''}:{sorted_ids}".encode()
|
||||
).hexdigest()[:16]
|
||||
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
|
||||
|
||||
redis = await get_redis_async()
|
||||
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"[STREAM] Duplicate user message blocked for session {session_id}, "
|
||||
f"hash={content_hash} — returning empty SSE",
|
||||
)
|
||||
return None
|
||||
|
||||
return _DedupLock(key, redis)
|
||||
@@ -1,94 +0,0 @@
|
||||
"""Unit tests for backend.copilot.message_dedup."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
|
||||
|
||||
|
||||
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
return mock_redis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_when_no_message_no_files(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Nothing to deduplicate — no Redis call made, None returned."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
result = await acquire_dedup_lock("sess-1", None, None)
|
||||
assert result is None
|
||||
mock_redis.set.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_lock_on_first_request(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""First request acquires the lock and returns a _DedupLock."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
mock_redis.set.assert_called_once()
|
||||
key_arg = mock_redis.set.call_args.args[0]
|
||||
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_on_duplicate(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Duplicate request (NX fails) returns None to signal the caller."""
|
||||
_patch_redis(mocker, set_returns=None)
|
||||
result = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_key_stable_across_file_order(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""File IDs are sorted before hashing so order doesn't affect the key."""
|
||||
mock_redis_1 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
|
||||
key_ab = mock_redis_1.set.call_args.args[0]
|
||||
|
||||
mock_redis_2 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
|
||||
key_ba = mock_redis_2.set.call_args.args[0]
|
||||
|
||||
assert key_ab == key_ba
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_deletes_key(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() calls Redis delete exactly once."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release()
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_swallows_redis_error(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() must not raise even when Redis delete fails."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release() # must not raise
|
||||
mock_redis.delete.assert_called_once()
|
||||
@@ -1,9 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Self, cast
|
||||
from weakref import WeakValueDictionary
|
||||
from typing import Any, AsyncIterator, Self, cast
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
@@ -522,10 +521,7 @@ async def upsert_chat_session(
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
async with _get_session_lock(session.session_id) as _:
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
@@ -651,20 +647,50 @@ async def _save_session_to_db(
|
||||
msg.sequence = existing_message_count + i
|
||||
|
||||
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
async def append_and_save_message(
|
||||
session_id: str, message: ChatMessage
|
||||
) -> ChatSession | None:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
Acquires the session lock, re-fetches the latest session state,
|
||||
appends the message, and saves — preventing message loss when
|
||||
concurrent requests modify the same session.
|
||||
"""
|
||||
lock = await _get_session_lock(session_id)
|
||||
Returns the updated session, or None if the message was detected as a
|
||||
duplicate (idempotency guard). Callers must check for None and skip any
|
||||
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
|
||||
|
||||
async with lock:
|
||||
session = await get_chat_session(session_id)
|
||||
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
|
||||
The idempotency check below provides a last-resort guard when the lock degrades.
|
||||
"""
|
||||
async with _get_session_lock(session_id) as lock_acquired:
|
||||
# When the lock degraded (Redis down or 2s timeout), bypass cache for
|
||||
# the idempotency check. Stale cache could let two concurrent writers
|
||||
# both see the old state, pass the check, and write the same message.
|
||||
if lock_acquired:
|
||||
session = await get_chat_session(session_id)
|
||||
else:
|
||||
session = await _get_session_from_db(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
# Idempotency: skip if the trailing block of same-role messages already
|
||||
# contains this content. Uses is_message_duplicate which checks all
|
||||
# consecutive trailing messages of the same role, not just [-1].
|
||||
#
|
||||
# This collapses infra/nginx retries whether they land on the same pod
|
||||
# (serialised by the Redis lock) or a different pod.
|
||||
#
|
||||
# Legit same-text messages are distinguished by the assistant turn
|
||||
# between them: if the user said "yes", got a response, and says
|
||||
# "yes" again, session.messages[-1] is the assistant reply, so the
|
||||
# role check fails and the second message goes through normally.
|
||||
#
|
||||
# Edge case: if a turn dies without writing any assistant message,
|
||||
# the user's next send of the same text is blocked here permanently.
|
||||
# The fix is to ensure failed turns always write an error/timeout
|
||||
# assistant message so the session always ends on an assistant turn.
|
||||
if message.content is not None and is_message_duplicate(
|
||||
session.messages, message.role, message.content
|
||||
):
|
||||
return None # duplicate — caller should skip enqueue
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
@@ -679,6 +705,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||
# Invalidate the stale entry so future reads fall back to DB,
|
||||
# preventing a retry from bypassing the idempotency check above.
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
return session
|
||||
|
||||
@@ -764,10 +793,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
@@ -832,25 +857,38 @@ async def update_session_title(
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
|
||||
"""Distributed Redis lock for a session, usable as an async context manager.
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
Yields True if the lock was acquired, False if it timed out or Redis was
|
||||
unavailable. Callers should treat False as a degraded mode and prefer fresh
|
||||
DB reads over cache to avoid acting on stale state.
|
||||
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
|
||||
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
|
||||
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
_lock_key = f"copilot:session_lock:{session_id}"
|
||||
lock = None
|
||||
acquired = False
|
||||
try:
|
||||
_redis = await get_redis_async()
|
||||
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
|
||||
acquired = await lock.acquire(blocking=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
"Could not acquire session lock for %s within 2s", session_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
|
||||
|
||||
try:
|
||||
yield acquired
|
||||
finally:
|
||||
if acquired and lock is not None:
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception:
|
||||
pass # TTL will expire the key
|
||||
|
||||
@@ -11,11 +11,13 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
append_and_save_message,
|
||||
get_chat_session,
|
||||
is_message_duplicate,
|
||||
maybe_append_user_message,
|
||||
@@ -574,3 +576,345 @@ def test_maybe_append_assistant_skips_duplicate():
|
||||
result = maybe_append_user_message(session, "dup", is_user_message=False)
|
||||
assert result is False
|
||||
assert len(session.messages) == 2
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# append_and_save_message #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
|
||||
s = ChatSession.new(user_id="u1", dry_run=False)
|
||||
s.messages = list(msgs)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_returns_none_for_duplicate(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message returns None when the trailing message is a duplicate."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="user", content="hello"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
|
||||
result = await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="hello")
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_appends_new_message(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message appends a non-duplicate message and returns the session."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=2)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="second message")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
assert result is not None
|
||||
assert result.messages[-1].content == "second message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_raises_when_session_not_found(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message raises ValueError when the session does not exist."""
|
||||
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await append_and_save_message(
|
||||
"missing-session-id", ChatMessage(role="user", content="hi")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_uses_db_when_lock_degraded(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When the Redis lock times out (acquired=False), the fallback reads from DB."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mock_get_from_db = mocker.patch(
|
||||
"backend.copilot.model._get_session_from_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
# DB path was used (not cache-first)
|
||||
mock_get_from_db.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_raises_database_error_on_save_failure(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
|
||||
from backend.util.exceptions import DatabaseError
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=RuntimeError("db down"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(DatabaseError):
|
||||
await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="new msg")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=RuntimeError("redis write failed"),
|
||||
)
|
||||
mock_invalidate = mocker.patch(
|
||||
"backend.copilot.model.invalidate_session_cache",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
result = await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="new msg")
|
||||
)
|
||||
# DB write succeeded, cache invalidation was called
|
||||
mock_invalidate.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_uses_db_when_redis_unavailable(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=ConnectionError("redis down"),
|
||||
)
|
||||
mock_get_from_db = mocker.patch(
|
||||
"backend.copilot.model._get_session_from_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
mock_get_from_db.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_lock_release_failure_is_ignored(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""If lock.release() raises, the exception is swallowed (TTL will clean up)."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock(
|
||||
side_effect=RuntimeError("release failed")
|
||||
)
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
assert result is not None
|
||||
|
||||
@@ -0,0 +1,384 @@
|
||||
"""Shared helpers for draining and injecting pending messages.
|
||||
|
||||
Used by both the baseline and SDK copilot paths to avoid duplicating
|
||||
the try/except drain, format, insert, and persist patterns.
|
||||
|
||||
Also provides the call-rate-limit check for the queue endpoint so
|
||||
routes.py stays free of Redis/Lua details.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatMessage, upsert_chat_session
|
||||
from backend.copilot.pending_messages import (
|
||||
MAX_PENDING_MESSAGES,
|
||||
PendingMessage,
|
||||
PendingMessageContext,
|
||||
drain_pending_messages,
|
||||
format_pending_as_user_message,
|
||||
push_pending_message,
|
||||
)
|
||||
from backend.copilot.stream_registry import get_session as get_active_session_meta
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.redis_helpers import incr_with_ttl
|
||||
from backend.data.workspace import resolve_workspace_files
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Call-frequency cap for the pending-message endpoint. The token-budget
|
||||
# check guards against overspend but not rapid-fire pushes from a client
|
||||
# with a large budget.
|
||||
PENDING_CALL_LIMIT = 30
|
||||
PENDING_CALL_WINDOW_SECONDS = 60
|
||||
_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
|
||||
|
||||
|
||||
async def is_turn_in_flight(session_id: str) -> bool:
|
||||
"""Return ``True`` when a copilot turn is actively running for *session_id*.
|
||||
|
||||
Used by the unified POST /stream entry point and the autopilot block so
|
||||
a second message arriving while an earlier turn is still executing gets
|
||||
queued into the pending buffer instead of racing the in-flight turn on
|
||||
the cluster lock.
|
||||
"""
|
||||
active = await get_active_session_meta(session_id)
|
||||
return active is not None and active.status == "running"
|
||||
|
||||
|
||||
class QueuePendingMessageResponse(BaseModel):
|
||||
"""Response returned by ``POST /stream`` with status 202 when a message
|
||||
is queued because the session already has a turn in flight.
|
||||
|
||||
- ``buffer_length``: how many messages are now in the session's
|
||||
pending buffer (after this push)
|
||||
- ``max_buffer_length``: the per-session cap (server-side constant)
|
||||
- ``turn_in_flight``: ``True`` if a copilot turn was running when
|
||||
we checked — purely informational for UX feedback. Always ``True``
|
||||
for responses from ``POST /stream`` with status 202.
|
||||
"""
|
||||
|
||||
buffer_length: int
|
||||
max_buffer_length: int
|
||||
turn_in_flight: bool
|
||||
|
||||
|
||||
async def queue_user_message(
|
||||
*,
|
||||
session_id: str,
|
||||
message: str,
|
||||
context: PendingMessageContext | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> QueuePendingMessageResponse:
|
||||
"""Push *message* into the per-session pending buffer.
|
||||
|
||||
The shared primitive for "a message arrived while a turn is in flight" —
|
||||
called from the unified POST /stream handler and the autopilot block.
|
||||
Call-frequency rate limiting is the caller's responsibility (HTTP path
|
||||
enforces it; internal block callers skip it).
|
||||
"""
|
||||
pending = PendingMessage(
|
||||
content=message,
|
||||
file_ids=file_ids or [],
|
||||
context=context,
|
||||
)
|
||||
new_len = await push_pending_message(session_id, pending)
|
||||
return QueuePendingMessageResponse(
|
||||
buffer_length=new_len,
|
||||
max_buffer_length=MAX_PENDING_MESSAGES,
|
||||
turn_in_flight=await is_turn_in_flight(session_id),
|
||||
)
|
||||
|
||||
|
||||
async def queue_pending_for_http(
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
context: dict[str, str] | None,
|
||||
file_ids: list[str] | None,
|
||||
) -> QueuePendingMessageResponse:
|
||||
"""HTTP-facing wrapper around :func:`queue_user_message`.
|
||||
|
||||
Owns the HTTP-only concerns that sat inline in ``stream_chat_post``:
|
||||
|
||||
1. Per-user call-rate cap (429 on overflow).
|
||||
2. File-ID sanitisation against the user's own workspace.
|
||||
3. ``{url, content}`` dict → ``PendingMessageContext`` coercion.
|
||||
4. Push via ``queue_user_message``.
|
||||
|
||||
Raises :class:`HTTPException` with status 429 if the rate cap is hit;
|
||||
otherwise returns the ``QueuePendingMessageResponse`` the handler can
|
||||
serialise 1:1 into the 202 body.
|
||||
"""
|
||||
call_count = await check_pending_call_rate(user_id)
|
||||
if call_count > PENDING_CALL_LIMIT:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=(
|
||||
f"Too many queued message requests this minute: limit is "
|
||||
f"{PENDING_CALL_LIMIT} per {PENDING_CALL_WINDOW_SECONDS}s "
|
||||
"across all sessions"
|
||||
),
|
||||
)
|
||||
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if file_ids:
|
||||
files = await resolve_workspace_files(user_id, file_ids)
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
|
||||
# ``PendingMessageContext`` uses the default ``extra='ignore'`` so
|
||||
# unknown keys in the loose HTTP-level ``context`` dict are silently
|
||||
# dropped rather than raising ``ValidationError`` + 500ing (sentry
|
||||
# r3105553772). The strict mode would only help protect against
|
||||
# typos, but the upstream ``StreamChatRequest.context: dict[str, str]``
|
||||
# is already schemaless, so the strict mode adds no real safety.
|
||||
queue_context = PendingMessageContext.model_validate(context) if context else None
|
||||
return await queue_user_message(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
context=queue_context,
|
||||
file_ids=sanitized_file_ids,
|
||||
)
|
||||
|
||||
|
||||
async def check_pending_call_rate(user_id: str) -> int:
|
||||
"""Increment and return the per-user push counter for the current window.
|
||||
|
||||
The counter is **user-global**: it counts pushes across ALL sessions
|
||||
belonging to the user, not per-session. This prevents a client from
|
||||
bypassing the cap by spreading rapid pushes across many sessions.
|
||||
|
||||
Returns the new call count. Raises nothing — callers compare the
|
||||
return value against ``PENDING_CALL_LIMIT`` and decide what to do.
|
||||
Fails open (returns 0) if Redis is unavailable so the endpoint stays
|
||||
usable during Redis hiccups.
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}"
|
||||
return await incr_with_ttl(redis, key, PENDING_CALL_WINDOW_SECONDS)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"pending_message_helpers: call-rate check failed for user=%s, failing open",
|
||||
user_id,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
async def drain_pending_safe(
|
||||
session_id: str, log_prefix: str = ""
|
||||
) -> list[PendingMessage]:
|
||||
"""Drain the pending buffer and return the full ``PendingMessage`` objects.
|
||||
|
||||
Returns ``[]`` on any Redis error so callers can always treat the
|
||||
result as a plain list. Callers that only need the rendered string
|
||||
(turn-start injection, auto-continue combined prompt) wrap this with
|
||||
:func:`pending_texts_from` — we return the structured objects so the
|
||||
re-queue rollback path can preserve ``file_ids`` / ``context`` that
|
||||
would otherwise be stripped by a text-only conversion.
|
||||
"""
|
||||
try:
|
||||
return await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"%s drain_pending_messages failed, skipping",
|
||||
log_prefix or "pending_messages",
|
||||
exc_info=True,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def pending_texts_from(pending: list[PendingMessage]) -> list[str]:
|
||||
"""Render a list of ``PendingMessage`` objects into plain text strings.
|
||||
|
||||
Shared helper for the two callers that need the rendered form:
|
||||
turn-start injection (bundles the pending block into the user prompt)
|
||||
and the auto-continue combined-message path.
|
||||
"""
|
||||
return [format_pending_as_user_message(pm)["content"] for pm in pending]
|
||||
|
||||
|
||||
def combine_pending_with_current(
|
||||
pending: list[PendingMessage],
|
||||
current_message: str | None,
|
||||
*,
|
||||
request_arrival_at: float,
|
||||
) -> str:
|
||||
"""Order pending messages around *current_message* by typing time.
|
||||
|
||||
Pending messages whose ``enqueued_at`` is strictly greater than
|
||||
``request_arrival_at`` were typed AFTER the user hit enter to start
|
||||
the current turn (the "race" path: queued into the pending buffer
|
||||
while ``/stream`` was still processing on the server). They belong
|
||||
chronologically AFTER the current message.
|
||||
|
||||
Pending messages whose ``enqueued_at`` is less than or equal to
|
||||
``request_arrival_at`` were typed BEFORE the current turn — usually
|
||||
from a prior in-flight window that auto-continue didn't consume.
|
||||
They belong BEFORE the current message.
|
||||
|
||||
Stable-sort within each bucket preserves enqueue order for messages
|
||||
typed in the same phase. Legacy ``PendingMessage`` objects with no
|
||||
``enqueued_at`` (written by older workers, defaulted to 0.0) sort as
|
||||
"before everything" — the pre-fix behaviour, which is a safe default
|
||||
for the rare queue entries that outlived a deploy.
|
||||
"""
|
||||
before: list[PendingMessage] = []
|
||||
after: list[PendingMessage] = []
|
||||
for pm in pending:
|
||||
if request_arrival_at > 0 and pm.enqueued_at > request_arrival_at:
|
||||
after.append(pm)
|
||||
else:
|
||||
before.append(pm)
|
||||
parts = pending_texts_from(before)
|
||||
if current_message and current_message.strip():
|
||||
parts.append(current_message)
|
||||
parts.extend(pending_texts_from(after))
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def insert_pending_before_last(session: "ChatSession", texts: list[str]) -> None:
|
||||
"""Insert pending messages into *session* just before the last message.
|
||||
|
||||
Pending messages were queued during the previous turn, so they belong
|
||||
chronologically before the current user message that was already
|
||||
appended via ``maybe_append_user_message``. Inserting at ``len-1``
|
||||
preserves that order: [...history, pending_1, pending_2, current_msg].
|
||||
|
||||
The caller must have already appended the current user message before
|
||||
calling this function. If ``session.messages`` is unexpectedly empty,
|
||||
a warning is logged and the messages are appended at index 0 so they
|
||||
are not silently lost.
|
||||
"""
|
||||
if not texts:
|
||||
return
|
||||
if not session.messages:
|
||||
logger.warning(
|
||||
"insert_pending_before_last: session.messages is empty — "
|
||||
"current user message was not appended before drain; "
|
||||
"inserting pending messages at index 0"
|
||||
)
|
||||
insert_idx = max(0, len(session.messages) - 1)
|
||||
for i, content in enumerate(texts):
|
||||
session.messages.insert(
|
||||
insert_idx + i, ChatMessage(role="user", content=content)
|
||||
)
|
||||
|
||||
|
||||
async def persist_session_safe(
|
||||
session: "ChatSession", log_prefix: str = ""
|
||||
) -> "ChatSession":
|
||||
"""Persist *session* to the DB, returning the (possibly updated) session.
|
||||
|
||||
Swallows transient DB errors so a failing persist doesn't discard
|
||||
messages already popped from Redis — the turn continues from memory.
|
||||
"""
|
||||
try:
|
||||
return await upsert_chat_session(session)
|
||||
except Exception as err:
|
||||
logger.warning(
|
||||
"%s Failed to persist pending messages: %s",
|
||||
log_prefix or "pending_messages",
|
||||
err,
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
async def persist_pending_as_user_rows(
|
||||
session: "ChatSession",
|
||||
transcript_builder: "TranscriptBuilder",
|
||||
pending: list[PendingMessage],
|
||||
*,
|
||||
log_prefix: str,
|
||||
content_of: Callable[[PendingMessage], str] = lambda pm: pm.content,
|
||||
on_rollback: Callable[[int], None] | None = None,
|
||||
) -> bool:
|
||||
"""Append ``pending`` as user rows to *session* + *transcript_builder*,
|
||||
persist, and roll back + re-queue if the persist silently failed.
|
||||
|
||||
This is the shared mid-turn follow-up persist used by both the baseline
|
||||
and SDK paths — they differ only in (a) how they derive the displayed
|
||||
string from a ``PendingMessage`` and (b) what extra per-path state
|
||||
(e.g. ``openai_messages``) needs trimming on rollback. Those variance
|
||||
points are exposed as ``content_of`` and ``on_rollback``.
|
||||
|
||||
Flow:
|
||||
1. Snapshot transcript + record the session.messages length.
|
||||
2. Append one user row per pending message to both stores.
|
||||
3. ``persist_session_safe`` — swallowed errors mean no sequences get
|
||||
back-filled, which we use as the failure signal.
|
||||
4. If any newly-appended row has ``sequence is None`` → rollback:
|
||||
delete the appended rows, restore the transcript snapshot, call
|
||||
``on_rollback(anchor)`` for the caller's own state, then re-push
|
||||
each ``PendingMessage`` into the primary pending buffer so the
|
||||
next turn-start drain picks them up.
|
||||
|
||||
Returns ``True`` when the rows were persisted with sequences, ``False``
|
||||
when the rollback path fired. Callers can use this to decide whether
|
||||
to log success or continue a retry loop.
|
||||
"""
|
||||
if not pending:
|
||||
return True
|
||||
|
||||
session_anchor = len(session.messages)
|
||||
transcript_snapshot = transcript_builder.snapshot()
|
||||
|
||||
for pm in pending:
|
||||
content = content_of(pm)
|
||||
session.messages.append(ChatMessage(role="user", content=content))
|
||||
transcript_builder.append_user(content=content)
|
||||
|
||||
# ``persist_session_safe`` may return a ``model_copy`` of *session* (e.g.
|
||||
# when ``upsert_chat_session`` patches a concurrently-updated title).
|
||||
# Do NOT reassign the caller's reference — the caller already pushed the
|
||||
# rows into its own ``session.messages`` above, and rollback below MUST
|
||||
# delete from that same list. Inspect the returned object only to learn
|
||||
# whether sequences were back-filled; if so, copy them onto the caller's
|
||||
# objects so the session stays internally consistent for downstream
|
||||
# ``append_and_save_message`` calls.
|
||||
persisted = await persist_session_safe(session, log_prefix)
|
||||
persisted_tail = persisted.messages[session_anchor:]
|
||||
if len(persisted_tail) == len(pending) and all(
|
||||
m.sequence is not None for m in persisted_tail
|
||||
):
|
||||
for caller_msg, persisted_msg in zip(
|
||||
session.messages[session_anchor:], persisted_tail
|
||||
):
|
||||
caller_msg.sequence = persisted_msg.sequence
|
||||
newly_appended = session.messages[session_anchor:]
|
||||
|
||||
if any(m.sequence is None for m in newly_appended):
|
||||
logger.warning(
|
||||
"%s Mid-turn follow-up persist did not back-fill sequences; "
|
||||
"rolling back %d row(s) and re-queueing into the primary buffer",
|
||||
log_prefix,
|
||||
len(pending),
|
||||
)
|
||||
del session.messages[session_anchor:]
|
||||
transcript_builder.restore(transcript_snapshot)
|
||||
if on_rollback is not None:
|
||||
on_rollback(session_anchor)
|
||||
for pm in pending:
|
||||
try:
|
||||
await push_pending_message(session.session_id, pm)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s Failed to re-queue mid-turn follow-up on rollback",
|
||||
log_prefix,
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
"%s Persisted %d mid-turn follow-up user row(s)",
|
||||
log_prefix,
|
||||
len(pending),
|
||||
)
|
||||
return True
|
||||
@@ -0,0 +1,472 @@
|
||||
"""Unit tests for pending_message_helpers."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot import pending_message_helpers as helpers_module
|
||||
from backend.copilot.pending_message_helpers import (
|
||||
PENDING_CALL_LIMIT,
|
||||
check_pending_call_rate,
|
||||
combine_pending_with_current,
|
||||
drain_pending_safe,
|
||||
insert_pending_before_last,
|
||||
persist_session_safe,
|
||||
)
|
||||
from backend.copilot.pending_messages import PendingMessage
|
||||
|
||||
# ── check_pending_call_rate ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_pending_call_rate_returns_count(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
helpers_module, "get_redis_async", AsyncMock(return_value=MagicMock())
|
||||
)
|
||||
monkeypatch.setattr(helpers_module, "incr_with_ttl", AsyncMock(return_value=3))
|
||||
|
||||
result = await check_pending_call_rate("user-1")
|
||||
assert result == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_pending_call_rate_fails_open_on_redis_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
helpers_module,
|
||||
"get_redis_async",
|
||||
AsyncMock(side_effect=ConnectionError("down")),
|
||||
)
|
||||
|
||||
result = await check_pending_call_rate("user-1")
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_pending_call_rate_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
helpers_module, "get_redis_async", AsyncMock(return_value=MagicMock())
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
helpers_module,
|
||||
"incr_with_ttl",
|
||||
AsyncMock(return_value=PENDING_CALL_LIMIT + 1),
|
||||
)
|
||||
|
||||
result = await check_pending_call_rate("user-1")
|
||||
assert result > PENDING_CALL_LIMIT
|
||||
|
||||
|
||||
# ── drain_pending_safe ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_pending_safe_returns_pending_messages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""``drain_pending_safe`` now returns the structured ``PendingMessage``
|
||||
objects (not pre-formatted strings) so the auto-continue re-queue path
|
||||
can preserve ``file_ids`` / ``context`` on rollback."""
|
||||
msgs = [
|
||||
PendingMessage(content="hello", file_ids=["f1"]),
|
||||
PendingMessage(content="world"),
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
helpers_module, "drain_pending_messages", AsyncMock(return_value=msgs)
|
||||
)
|
||||
|
||||
result = await drain_pending_safe("sess-1")
|
||||
assert result == msgs
|
||||
# Structured metadata survives — the bug r3105523410 guard.
|
||||
assert result[0].file_ids == ["f1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_pending_safe_returns_empty_on_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
helpers_module,
|
||||
"drain_pending_messages",
|
||||
AsyncMock(side_effect=RuntimeError("redis down")),
|
||||
)
|
||||
|
||||
result = await drain_pending_safe("sess-1", "[Test]")
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_pending_safe_empty_buffer(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
helpers_module, "drain_pending_messages", AsyncMock(return_value=[])
|
||||
)
|
||||
|
||||
result = await drain_pending_safe("sess-1")
|
||||
assert result == []
|
||||
|
||||
|
||||
# ── combine_pending_with_current ───────────────────────────────────────
|
||||
|
||||
|
||||
def test_combine_before_current_when_pending_older() -> None:
|
||||
"""Pending typed before the /stream request → goes ahead of current
|
||||
(prior-turn / inter-turn case)."""
|
||||
pending = [
|
||||
PendingMessage(content="older_a", enqueued_at=100.0),
|
||||
PendingMessage(content="older_b", enqueued_at=110.0),
|
||||
]
|
||||
result = combine_pending_with_current(
|
||||
pending, "current_msg", request_arrival_at=120.0
|
||||
)
|
||||
assert result == "older_a\n\nolder_b\n\ncurrent_msg"
|
||||
|
||||
|
||||
def test_combine_after_current_when_pending_newer() -> None:
|
||||
"""Pending queued AFTER the /stream request arrived → goes after
|
||||
current. This is the race path where user hits enter twice in quick
|
||||
succession (second press goes through the queue endpoint while the
|
||||
first /stream is still processing)."""
|
||||
pending = [
|
||||
PendingMessage(content="race_followup", enqueued_at=125.0),
|
||||
]
|
||||
result = combine_pending_with_current(
|
||||
pending, "current_msg", request_arrival_at=120.0
|
||||
)
|
||||
assert result == "current_msg\n\nrace_followup"
|
||||
|
||||
|
||||
def test_combine_mixed_before_and_after() -> None:
|
||||
"""Mixed bucket: older items first, current, then newer race items."""
|
||||
pending = [
|
||||
PendingMessage(content="way_older", enqueued_at=50.0),
|
||||
PendingMessage(content="race_fast_follow", enqueued_at=125.0),
|
||||
PendingMessage(content="also_older", enqueued_at=80.0),
|
||||
]
|
||||
result = combine_pending_with_current(
|
||||
pending, "current_msg", request_arrival_at=120.0
|
||||
)
|
||||
# Enqueue order preserved within each bucket (stable partition).
|
||||
assert result == "way_older\n\nalso_older\n\ncurrent_msg\n\nrace_fast_follow"
|
||||
|
||||
|
||||
def test_combine_no_current_joins_pending() -> None:
|
||||
"""Auto-continue case: no current message, just drained pending."""
|
||||
pending = [PendingMessage(content="a"), PendingMessage(content="b")]
|
||||
result = combine_pending_with_current(pending, None, request_arrival_at=0.0)
|
||||
assert result == "a\n\nb"
|
||||
|
||||
|
||||
def test_combine_legacy_zero_timestamp_sorts_before() -> None:
|
||||
"""A ``PendingMessage`` from before this field existed (default 0.0)
|
||||
should sort as "before everything" — safe pre-fix behaviour."""
|
||||
pending = [PendingMessage(content="legacy", enqueued_at=0.0)]
|
||||
result = combine_pending_with_current(
|
||||
pending, "current_msg", request_arrival_at=120.0
|
||||
)
|
||||
assert result == "legacy\n\ncurrent_msg"
|
||||
|
||||
|
||||
def test_combine_missing_request_arrival_falls_back_to_before() -> None:
|
||||
"""If the HTTP handler didn't stamp ``request_arrival_at`` (0.0
|
||||
default — older queue entries) the combine degrades gracefully to
|
||||
the pre-fix behaviour: all pending goes before current."""
|
||||
pending = [
|
||||
PendingMessage(content="a", enqueued_at=500.0),
|
||||
PendingMessage(content="b", enqueued_at=1000.0),
|
||||
]
|
||||
result = combine_pending_with_current(pending, "current", request_arrival_at=0.0)
|
||||
assert result == "a\n\nb\n\ncurrent"
|
||||
|
||||
|
||||
# ── insert_pending_before_last ─────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_session(*contents: str) -> Any:
|
||||
session = MagicMock()
|
||||
session.messages = [MagicMock(role="user", content=c) for c in contents]
|
||||
return session
|
||||
|
||||
|
||||
def test_insert_pending_before_last_single_existing_message() -> None:
|
||||
session = _make_session("current")
|
||||
insert_pending_before_last(session, ["queued"])
|
||||
assert session.messages[0].content == "queued"
|
||||
assert session.messages[1].content == "current"
|
||||
|
||||
|
||||
def test_insert_pending_before_last_multiple_pending() -> None:
|
||||
session = _make_session("current")
|
||||
insert_pending_before_last(session, ["p1", "p2"])
|
||||
contents = [m.content for m in session.messages]
|
||||
assert contents == ["p1", "p2", "current"]
|
||||
|
||||
|
||||
def test_insert_pending_before_last_empty_session() -> None:
|
||||
session = _make_session()
|
||||
insert_pending_before_last(session, ["queued"])
|
||||
assert session.messages[0].content == "queued"
|
||||
|
||||
|
||||
def test_insert_pending_before_last_no_texts_is_noop() -> None:
|
||||
session = _make_session("current")
|
||||
insert_pending_before_last(session, [])
|
||||
assert len(session.messages) == 1
|
||||
|
||||
|
||||
# ── persist_session_safe ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_session_safe_returns_updated_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
original = MagicMock()
|
||||
updated = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
helpers_module, "upsert_chat_session", AsyncMock(return_value=updated)
|
||||
)
|
||||
|
||||
result = await persist_session_safe(original, "[Test]")
|
||||
assert result is updated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_session_safe_returns_original_on_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
original = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
helpers_module,
|
||||
"upsert_chat_session",
|
||||
AsyncMock(side_effect=Exception("db error")),
|
||||
)
|
||||
|
||||
result = await persist_session_safe(original, "[Test]")
|
||||
assert result is original
|
||||
|
||||
|
||||
# ── persist_pending_as_user_rows ───────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeTranscript:
|
||||
"""Minimal TranscriptBuilder shim — records append_user + snapshot/restore."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.entries: list[str] = []
|
||||
|
||||
def append_user(self, content: str, uuid: str | None = None) -> None:
|
||||
self.entries.append(content)
|
||||
|
||||
def snapshot(self) -> list[str]:
|
||||
return list(self.entries)
|
||||
|
||||
def restore(self, snap: list[str]) -> None:
|
||||
self.entries = list(snap)
|
||||
|
||||
|
||||
def _make_chat_message_class(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> Any:
|
||||
"""Return a simple ChatMessage stand-in that tracks sequence."""
|
||||
|
||||
class _Msg:
|
||||
def __init__(self, role: str, content: str) -> None:
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.sequence: int | None = None
|
||||
|
||||
monkeypatch.setattr(helpers_module, "ChatMessage", _Msg)
|
||||
return _Msg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_pending_empty_list_is_noop(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
|
||||
|
||||
_make_chat_message_class(monkeypatch)
|
||||
session = MagicMock()
|
||||
session.messages = []
|
||||
tb = _FakeTranscript()
|
||||
monkeypatch.setattr(helpers_module, "upsert_chat_session", AsyncMock())
|
||||
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
|
||||
|
||||
ok = await persist_pending_as_user_rows(session, tb, [], log_prefix="[T]")
|
||||
assert ok is True
|
||||
assert session.messages == []
|
||||
assert tb.entries == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_pending_happy_path_appends_and_returns_true(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
|
||||
from backend.copilot.pending_messages import PendingMessage as PM
|
||||
|
||||
_make_chat_message_class(monkeypatch)
|
||||
session = MagicMock()
|
||||
session.session_id = "sess"
|
||||
session.messages = []
|
||||
tb = _FakeTranscript()
|
||||
|
||||
async def _fake_upsert(sess: Any) -> Any:
|
||||
# Simulate the DB back-filling sequence numbers on success.
|
||||
for i, m in enumerate(sess.messages):
|
||||
m.sequence = i
|
||||
return sess
|
||||
|
||||
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fake_upsert)
|
||||
push_mock = AsyncMock()
|
||||
monkeypatch.setattr(helpers_module, "push_pending_message", push_mock)
|
||||
|
||||
pending = [PM(content="a"), PM(content="b")]
|
||||
ok = await persist_pending_as_user_rows(session, tb, pending, log_prefix="[T]")
|
||||
assert ok is True
|
||||
assert [m.content for m in session.messages] == ["a", "b"]
|
||||
assert tb.entries == ["a", "b"]
|
||||
push_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_pending_rollback_when_sequence_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
|
||||
from backend.copilot.pending_messages import PendingMessage as PM
|
||||
|
||||
_make_chat_message_class(monkeypatch)
|
||||
session = MagicMock()
|
||||
session.session_id = "sess"
|
||||
# Prior state — anchor point is len(messages) before the helper runs.
|
||||
session.messages = []
|
||||
tb = _FakeTranscript()
|
||||
tb.entries = ["earlier-entry"]
|
||||
|
||||
async def _fake_upsert_fails_silently(sess: Any) -> Any:
|
||||
# Simulate the "persist swallowed the error" branch — sequences stay None.
|
||||
return sess
|
||||
|
||||
monkeypatch.setattr(
|
||||
helpers_module, "upsert_chat_session", _fake_upsert_fails_silently
|
||||
)
|
||||
push_mock = AsyncMock()
|
||||
monkeypatch.setattr(helpers_module, "push_pending_message", push_mock)
|
||||
|
||||
pending = [PM(content="a"), PM(content="b")]
|
||||
ok = await persist_pending_as_user_rows(session, tb, pending, log_prefix="[T]")
|
||||
|
||||
assert ok is False
|
||||
# Rollback: session.messages trimmed to anchor, transcript restored.
|
||||
assert session.messages == []
|
||||
assert tb.entries == ["earlier-entry"]
|
||||
# Both pending messages re-queued.
|
||||
assert push_mock.await_count == 2
|
||||
assert push_mock.await_args_list[0].args[1] is pending[0]
|
||||
assert push_mock.await_args_list[1].args[1] is pending[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_pending_rollback_calls_on_rollback_hook(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Baseline's openai_messages trim runs via the on_rollback hook."""
|
||||
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
|
||||
from backend.copilot.pending_messages import PendingMessage as PM
|
||||
|
||||
_make_chat_message_class(monkeypatch)
|
||||
session = MagicMock()
|
||||
session.session_id = "sess"
|
||||
session.messages = []
|
||||
tb = _FakeTranscript()
|
||||
|
||||
async def _fails(sess: Any) -> Any:
|
||||
return sess
|
||||
|
||||
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fails)
|
||||
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
|
||||
|
||||
on_rollback_calls: list[int] = []
|
||||
|
||||
def _on_rollback(anchor: int) -> None:
|
||||
on_rollback_calls.append(anchor)
|
||||
|
||||
await persist_pending_as_user_rows(
|
||||
session,
|
||||
tb,
|
||||
[PM(content="x")],
|
||||
log_prefix="[T]",
|
||||
on_rollback=_on_rollback,
|
||||
)
|
||||
assert on_rollback_calls == [0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_pending_uses_custom_content_of(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
|
||||
from backend.copilot.pending_messages import PendingMessage as PM
|
||||
|
||||
_make_chat_message_class(monkeypatch)
|
||||
session = MagicMock()
|
||||
session.session_id = "sess"
|
||||
session.messages = []
|
||||
tb = _FakeTranscript()
|
||||
|
||||
async def _ok(sess: Any) -> Any:
|
||||
for i, m in enumerate(sess.messages):
|
||||
m.sequence = i
|
||||
return sess
|
||||
|
||||
monkeypatch.setattr(helpers_module, "upsert_chat_session", _ok)
|
||||
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
|
||||
|
||||
await persist_pending_as_user_rows(
|
||||
session,
|
||||
tb,
|
||||
[PM(content="raw")],
|
||||
log_prefix="[T]",
|
||||
content_of=lambda pm: f"FORMATTED:{pm.content}",
|
||||
)
|
||||
assert session.messages[0].content == "FORMATTED:raw"
|
||||
assert tb.entries == ["FORMATTED:raw"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_pending_swallows_requeue_errors(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A broken push_pending_message on rollback must not raise upward —
|
||||
the rollback still needs to trim state even if re-queue fails."""
|
||||
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
|
||||
from backend.copilot.pending_messages import PendingMessage as PM
|
||||
|
||||
_make_chat_message_class(monkeypatch)
|
||||
session = MagicMock()
|
||||
session.session_id = "sess"
|
||||
session.messages = []
|
||||
tb = _FakeTranscript()
|
||||
|
||||
async def _fails(sess: Any) -> Any:
|
||||
return sess
|
||||
|
||||
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fails)
|
||||
monkeypatch.setattr(
|
||||
helpers_module,
|
||||
"push_pending_message",
|
||||
AsyncMock(side_effect=RuntimeError("redis down")),
|
||||
)
|
||||
|
||||
ok = await persist_pending_as_user_rows(
|
||||
session, tb, [PM(content="x")], log_prefix="[T]"
|
||||
)
|
||||
# Still returns False (rolled back) — exception was logged + swallowed.
|
||||
assert ok is False
|
||||
450
autogpt_platform/backend/backend/copilot/pending_messages.py
Normal file
450
autogpt_platform/backend/backend/copilot/pending_messages.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""Pending-message buffer for in-flight copilot turns.
|
||||
|
||||
When a user sends a new message while a copilot turn is already executing,
|
||||
instead of blocking the frontend (or queueing a brand-new turn after the
|
||||
current one finishes), we want the new message to be *injected into the
|
||||
running turn* — appended between tool-call rounds so the model sees it
|
||||
before its next LLM call.
|
||||
|
||||
This module provides the cross-process buffer that makes that possible:
|
||||
|
||||
- **Producer** (chat API route): pushes a pending message to Redis and
|
||||
publishes a notification on a pub/sub channel.
|
||||
- **Consumer** (executor running the turn): on each tool-call round,
|
||||
drains the buffer and appends the pending messages to the conversation.
|
||||
|
||||
The Redis list is the durable store; the pub/sub channel is a fast
|
||||
wake-up hint for long-idle consumers (not used by default, but available
|
||||
for future blocking-wait semantics).
|
||||
|
||||
A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The
|
||||
buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.redis_helpers import capped_rpush
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-session cap. Higher values risk a runaway consumer; lower values
|
||||
# risk dropping user input under heavy typing. 10 was chosen as a
|
||||
# reasonable ceiling — a user typing faster than the copilot can drain
|
||||
# between tool rounds is already an unusual usage pattern.
|
||||
MAX_PENDING_MESSAGES = 10
|
||||
|
||||
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
|
||||
# executor dies, the pending messages should either have been drained
|
||||
# already or are safe to drop (the user can resend).
|
||||
_PENDING_KEY_PREFIX = "copilot:pending:"
|
||||
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
|
||||
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
|
||||
|
||||
# Secondary queue that carries drained-but-awaiting-persist PendingMessages
|
||||
# from the MCP tool wrapper (which drains the primary buffer and injects
|
||||
# into tool output for the LLM) to sdk/service.py's _dispatch_response
|
||||
# handler for StreamToolOutputAvailable, which pops and persists them as a
|
||||
# separate user row chronologically after the tool_result row. This is the
|
||||
# hand-off between "Claude saw the follow-up mid-turn" (wrapper) and "UI
|
||||
# renders a user bubble for it" (service). Rollback path re-queues into
|
||||
# the PRIMARY buffer so the next turn-start drain picks them up if the
|
||||
# user-row persist fails.
|
||||
_PERSIST_QUEUE_KEY_PREFIX = "copilot:pending-persist:"
|
||||
|
||||
# Payload sent on the pub/sub notify channel. Subscribers treat any
|
||||
# message as a wake-up hint; the value itself is not meaningful.
|
||||
_NOTIFY_PAYLOAD = "1"
|
||||
|
||||
|
||||
class PendingMessageContext(BaseModel):
|
||||
"""Structured page context attached to a pending message.
|
||||
|
||||
Default ``extra='ignore'`` (pydantic's default): unknown keys from
|
||||
the loose HTTP-level ``StreamChatRequest.context: dict[str, str]``
|
||||
are silently dropped rather than raising ``ValidationError`` on
|
||||
forward-compat additions. The strict ``extra='forbid'`` mode was
|
||||
removed after sentry r3105553772 — strict validation at this
|
||||
boundary only added a 500 footgun; the upstream request model is
|
||||
already schemaless so strict mode protects nothing.
|
||||
"""
|
||||
|
||||
url: str | None = Field(default=None, max_length=2_000)
|
||||
content: str | None = Field(default=None, max_length=32_000)
|
||||
|
||||
|
||||
class PendingMessage(BaseModel):
|
||||
"""A user message queued for injection into an in-flight turn."""
|
||||
|
||||
content: str = Field(min_length=1, max_length=32_000)
|
||||
file_ids: list[str] = Field(default_factory=list, max_length=20)
|
||||
context: PendingMessageContext | None = None
|
||||
# Wall-clock time (unix seconds, float) the message was queued by the
|
||||
# user. Used by the turn-start drain to order pending relative to the
|
||||
# turn's ``current`` message: items typed *before* the current's
|
||||
# /stream arrival go ahead of it; items typed *after* (race path,
|
||||
# queued while the /stream HTTP request was still processing) go
|
||||
# after. Defaults to 0.0 for backward compatibility with entries
|
||||
# written before this field existed — those sort as "before everything"
|
||||
# which matches the pre-fix behaviour.
|
||||
enqueued_at: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
def _buffer_key(session_id: str) -> str:
|
||||
return f"{_PENDING_KEY_PREFIX}{session_id}"
|
||||
|
||||
|
||||
def _notify_channel(session_id: str) -> str:
|
||||
return f"{_PENDING_CHANNEL_PREFIX}{session_id}"
|
||||
|
||||
|
||||
def _decode_redis_item(item: Any) -> str:
|
||||
"""Decode a redis-py list item to a str.
|
||||
|
||||
redis-py returns ``bytes`` when ``decode_responses=False`` and ``str``
|
||||
when ``decode_responses=True``. This helper handles both so callers
|
||||
don't have to repeat the isinstance guard.
|
||||
"""
|
||||
return item.decode("utf-8") if isinstance(item, bytes) else str(item)
|
||||
|
||||
|
||||
async def push_pending_message(
|
||||
session_id: str,
|
||||
message: PendingMessage,
|
||||
) -> int:
|
||||
"""Append a pending message to the session's buffer.
|
||||
|
||||
Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
|
||||
trimming from the left (oldest) — the newest message always wins if
|
||||
the user has been typing faster than the copilot can drain.
|
||||
|
||||
Delegates to :func:`backend.data.redis_helpers.capped_rpush` so RPUSH
|
||||
+ LTRIM + EXPIRE + LLEN run atomically (MULTI/EXEC) in one round
|
||||
trip; a concurrent drain (LPOP) can no longer observe the list
|
||||
temporarily over ``MAX_PENDING_MESSAGES``.
|
||||
|
||||
Note on durability: if the executor turn crashes after a push but before
|
||||
the drain window runs, the message remains in Redis until the TTL expires
|
||||
(``_PENDING_TTL_SECONDS``, currently 1 hour). It is delivered on the
|
||||
next turn that drains the buffer. If no turn runs within the TTL the
|
||||
message is silently dropped; the user may resend it.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = _buffer_key(session_id)
|
||||
payload = message.model_dump_json()
|
||||
|
||||
new_length = await capped_rpush(
|
||||
redis,
|
||||
key,
|
||||
payload,
|
||||
max_len=MAX_PENDING_MESSAGES,
|
||||
ttl_seconds=_PENDING_TTL_SECONDS,
|
||||
)
|
||||
|
||||
# Fire-and-forget notify. Subscribers use this as a wake-up hint;
|
||||
# the buffer itself is authoritative so a lost notify is harmless.
|
||||
try:
|
||||
await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.warning("pending_messages: publish failed for %s: %s", session_id, e)
|
||||
|
||||
logger.info(
|
||||
"pending_messages: pushed message to session=%s (buffer_len=%d)",
|
||||
session_id,
|
||||
new_length,
|
||||
)
|
||||
return new_length
|
||||
|
||||
|
||||
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
|
||||
"""Atomically pop all pending messages for *session_id*.
|
||||
|
||||
Returns them in enqueue order (oldest first). Uses ``LPOP`` with a
|
||||
count so the read+delete is a single Redis round trip. If the list
|
||||
is empty or missing, returns ``[]``.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = _buffer_key(session_id)
|
||||
|
||||
# Redis LPOP with count (Redis 6.2+) returns None for missing key,
|
||||
# empty list if we somehow race an empty key, or the popped items.
|
||||
# Draining MAX_PENDING_MESSAGES at once is safe because the push side
|
||||
# uses RPUSH + LTRIM(-MAX_PENDING_MESSAGES, -1) to cap the list to that
|
||||
# same value, so the list can never hold more items than we drain here.
|
||||
# If the cap is raised on the push side, raise the drain count here too
|
||||
# (or switch to a loop drain).
|
||||
lpop_result = await redis.lpop(key, MAX_PENDING_MESSAGES) # type: ignore[assignment]
|
||||
if not lpop_result:
|
||||
return []
|
||||
raw_popped: list[Any] = list(lpop_result)
|
||||
|
||||
# redis-py may return bytes or str depending on decode_responses.
|
||||
decoded: list[str] = [_decode_redis_item(item) for item in raw_popped]
|
||||
|
||||
messages: list[PendingMessage] = []
|
||||
for payload in decoded:
|
||||
try:
|
||||
messages.append(PendingMessage.model_validate(json.loads(payload)))
|
||||
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
|
||||
logger.warning(
|
||||
"pending_messages: dropping malformed entry for %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
)
|
||||
|
||||
if messages:
|
||||
logger.info(
|
||||
"pending_messages: drained %d messages for session=%s",
|
||||
len(messages),
|
||||
session_id,
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
async def peek_pending_count(session_id: str) -> int:
|
||||
"""Return the current buffer length without consuming it."""
|
||||
redis = await get_redis_async()
|
||||
length = await cast("Any", redis.llen(_buffer_key(session_id)))
|
||||
return int(length)
|
||||
|
||||
|
||||
async def peek_pending_messages(session_id: str) -> list[PendingMessage]:
|
||||
"""Return pending messages without consuming them.
|
||||
|
||||
Uses LRANGE 0 -1 to read all items in enqueue order (oldest first)
|
||||
without removing them. Returns an empty list if the buffer is empty
|
||||
or the session has no pending messages.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = _buffer_key(session_id)
|
||||
items = await cast("Any", redis.lrange(key, 0, -1))
|
||||
if not items:
|
||||
return []
|
||||
messages: list[PendingMessage] = []
|
||||
for item in items:
|
||||
try:
|
||||
messages.append(
|
||||
PendingMessage.model_validate(json.loads(_decode_redis_item(item)))
|
||||
)
|
||||
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
|
||||
logger.warning(
|
||||
"pending_messages: dropping malformed peek entry for %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
async def _clear_pending_messages_unsafe(session_id: str) -> None:
|
||||
"""Drop the session's pending buffer — **not** the normal turn cleanup.
|
||||
|
||||
Named ``_unsafe`` because reaching for this at turn end drops queued
|
||||
follow-ups on the floor instead of running them (the bug fixed by
|
||||
commit b64be73). The atomic ``LPOP`` drain at turn start is the
|
||||
primary consumer; anything pushed after the drain window belongs to
|
||||
the next turn by definition. Retained only as an operator/debug
|
||||
escape hatch for manually clearing a stuck session and as a fixture
|
||||
in the unit tests.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_buffer_key(session_id))
|
||||
|
||||
|
||||
# Per-message and total-block caps for inline tool-boundary injection.
|
||||
# Per-message keeps a single long paste from dominating; the total cap
|
||||
# keeps the follow-up block small relative to the 100 KB MCP truncation
|
||||
# boundary so tool output always stays the larger share of the wrapper
|
||||
# return value.
|
||||
_FOLLOWUP_CONTENT_MAX_CHARS = 2_000
|
||||
_FOLLOWUP_TOTAL_MAX_CHARS = 6_000
|
||||
|
||||
|
||||
def _persist_queue_key(session_id: str) -> str:
|
||||
return f"{_PERSIST_QUEUE_KEY_PREFIX}{session_id}"
|
||||
|
||||
|
||||
async def stash_pending_for_persist(
|
||||
session_id: str,
|
||||
messages: list[PendingMessage],
|
||||
) -> None:
|
||||
"""Enqueue drained PendingMessages for UI-row persistence.
|
||||
|
||||
Writes each message as a JSON payload to
|
||||
``copilot:pending-persist:{session_id}``. The SDK service's
|
||||
tool-result dispatch handler LPOPs this queue right after appending
|
||||
the tool_result row to ``session.messages``, so the resulting user
|
||||
row lands at the correct chronological position (after the tool
|
||||
output the follow-up was drained against).
|
||||
|
||||
Fire-and-forget on Redis failures: a stash failure means Claude
|
||||
still saw the follow-up in tool output (the injection step ran
|
||||
first), so the only consequence is a missing UI bubble. Logged
|
||||
so it can be spotted.
|
||||
"""
|
||||
if not messages:
|
||||
return
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = _persist_queue_key(session_id)
|
||||
payloads = [m.model_dump_json() for m in messages]
|
||||
await redis.rpush(key, *payloads) # type: ignore[misc]
|
||||
await redis.expire(key, _PENDING_TTL_SECONDS) # type: ignore[misc]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"pending_messages: failed to stash %d message(s) for persist "
|
||||
"(session=%s); UI will miss the follow-up bubble but Claude "
|
||||
"already saw the content in tool output",
|
||||
len(messages),
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def drain_pending_for_persist(session_id: str) -> list[PendingMessage]:
|
||||
"""Atomically drain the persist queue for *session_id*.
|
||||
|
||||
Returns the queued ``PendingMessage`` objects in enqueue order (oldest
|
||||
first). Returns ``[]`` on any error so the service-layer caller can
|
||||
always treat the result as a plain list. Called by sdk/service.py
|
||||
after appending a tool_result row to ``session.messages``.
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
key = _persist_queue_key(session_id)
|
||||
lpop_result = await redis.lpop( # type: ignore[assignment]
|
||||
key, MAX_PENDING_MESSAGES
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"pending_messages: drain_pending_for_persist failed for session=%s",
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return []
|
||||
if not lpop_result:
|
||||
return []
|
||||
raw_popped: list[Any] = list(lpop_result)
|
||||
messages: list[PendingMessage] = []
|
||||
for item in raw_popped:
|
||||
try:
|
||||
messages.append(
|
||||
PendingMessage.model_validate(json.loads(_decode_redis_item(item)))
|
||||
)
|
||||
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
|
||||
logger.warning(
|
||||
"pending_messages: dropping malformed persist-queue entry "
|
||||
"for %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def format_pending_as_followup(pending: list[PendingMessage]) -> str:
|
||||
"""Render drained pending messages as a ``<user_follow_up>`` block.
|
||||
|
||||
Used by the SDK tool-boundary injection path to surface queued user
|
||||
text inside a tool result so the model reads it on the next LLM round,
|
||||
without starting a separate turn. Wrapped in a stable XML-style tag so
|
||||
the shared system-prompt supplement can teach the model to treat the
|
||||
contents as the user's continuation of their request, not as tool
|
||||
output. Each message is capped to keep the block bounded even if the
|
||||
user pastes long content.
|
||||
"""
|
||||
if not pending:
|
||||
return ""
|
||||
rendered: list[str] = []
|
||||
total_chars = 0
|
||||
dropped = 0
|
||||
for idx, pm in enumerate(pending, start=1):
|
||||
text = pm.content
|
||||
if len(text) > _FOLLOWUP_CONTENT_MAX_CHARS:
|
||||
text = text[:_FOLLOWUP_CONTENT_MAX_CHARS] + "… [truncated]"
|
||||
entry = f"Message {idx}:\n{text}"
|
||||
if pm.context and pm.context.url:
|
||||
entry += f"\n[Page URL: {pm.context.url}]"
|
||||
if pm.file_ids:
|
||||
entry += "\n[Attached files: " + ", ".join(pm.file_ids) + "]"
|
||||
if total_chars + len(entry) > _FOLLOWUP_TOTAL_MAX_CHARS:
|
||||
dropped = len(pending) - idx + 1
|
||||
break
|
||||
rendered.append(entry)
|
||||
total_chars += len(entry)
|
||||
if dropped:
|
||||
rendered.append(f"… [{dropped} more message(s) truncated]")
|
||||
body = "\n\n".join(rendered)
|
||||
return (
|
||||
"<user_follow_up>\n"
|
||||
"The user sent the following message(s) while this tool was running. "
|
||||
"Treat them as a continuation of their current request — acknowledge "
|
||||
"and act on them in your next response. Do not echo these tags back.\n\n"
|
||||
f"{body}\n"
|
||||
"</user_follow_up>"
|
||||
)
|
||||
|
||||
|
||||
async def drain_and_format_for_injection(
|
||||
session_id: str,
|
||||
*,
|
||||
log_prefix: str,
|
||||
) -> str:
|
||||
"""Drain the pending buffer and produce a ``<user_follow_up>`` block.
|
||||
|
||||
Shared entry point for every mid-turn injection site (``PostToolUse``
|
||||
hook for MCP + built-in tools, baseline between-rounds drain, etc.).
|
||||
Also stashes the drained messages on the persist queue so the service
|
||||
layer appends a real user row after the tool_result it rode in on —
|
||||
giving the UI a correctly-ordered bubble.
|
||||
|
||||
Returns an empty string if nothing was queued or Redis failed; callers
|
||||
can pass the result straight to ``additionalContext``.
|
||||
"""
|
||||
if not session_id:
|
||||
return ""
|
||||
try:
|
||||
pending = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"%s drain_pending_messages failed (session=%s); skipping injection",
|
||||
log_prefix,
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return ""
|
||||
if not pending:
|
||||
return ""
|
||||
logger.info(
|
||||
"%s Injected %d user follow-up(s) into tool output (session=%s)",
|
||||
log_prefix,
|
||||
len(pending),
|
||||
session_id,
|
||||
)
|
||||
await stash_pending_for_persist(session_id, pending)
|
||||
return format_pending_as_followup(pending)
|
||||
|
||||
|
||||
def format_pending_as_user_message(message: PendingMessage) -> dict[str, Any]:
|
||||
"""Shape a ``PendingMessage`` into the OpenAI-format user message dict.
|
||||
|
||||
Used by the baseline tool-call loop when injecting the buffered
|
||||
message into the conversation. Context/file metadata (if any) is
|
||||
embedded into the content so the model sees everything in one block.
|
||||
"""
|
||||
parts: list[str] = [message.content]
|
||||
if message.context:
|
||||
if message.context.url:
|
||||
parts.append(f"\n\n[Page URL: {message.context.url}]")
|
||||
if message.context.content:
|
||||
parts.append(f"\n\n[Page content]\n{message.context.content}")
|
||||
if message.file_ids:
|
||||
parts.append(
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(f"- file_id={fid}" for fid in message.file_ids)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
return {"role": "user", "content": "".join(parts)}
|
||||
@@ -0,0 +1,614 @@
|
||||
"""Tests for the copilot pending-messages buffer.
|
||||
|
||||
Uses a fake async Redis client so the tests don't require a real Redis
|
||||
instance (the backend test suite's DB/Redis fixtures are heavyweight
|
||||
and pull in the full app startup).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot import pending_messages as pm_module
|
||||
from backend.copilot.pending_messages import (
|
||||
MAX_PENDING_MESSAGES,
|
||||
PendingMessage,
|
||||
PendingMessageContext,
|
||||
_clear_pending_messages_unsafe,
|
||||
drain_and_format_for_injection,
|
||||
drain_pending_for_persist,
|
||||
drain_pending_messages,
|
||||
format_pending_as_followup,
|
||||
format_pending_as_user_message,
|
||||
peek_pending_count,
|
||||
peek_pending_messages,
|
||||
push_pending_message,
|
||||
stash_pending_for_persist,
|
||||
)
|
||||
|
||||
# ── Fake Redis ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
# Values are ``str | bytes`` because real redis-py returns
|
||||
# bytes when ``decode_responses=False``; the drain path must
|
||||
# handle both and our tests exercise both.
|
||||
self.lists: dict[str, list[str | bytes]] = {}
|
||||
self.published: list[tuple[str, str]] = []
|
||||
|
||||
async def rpush(self, key: str, *values: Any) -> int:
|
||||
lst = self.lists.setdefault(key, [])
|
||||
lst.extend(values)
|
||||
return len(lst)
|
||||
|
||||
async def ltrim(self, key: str, start: int, stop: int) -> None:
|
||||
lst = self.lists.get(key, [])
|
||||
# Redis LTRIM stop is inclusive; -1 means the last element.
|
||||
if stop == -1:
|
||||
self.lists[key] = lst[start:]
|
||||
else:
|
||||
self.lists[key] = lst[start : stop + 1]
|
||||
|
||||
async def expire(self, key: str, seconds: int) -> int:
|
||||
# Fake doesn't enforce TTL — just acknowledge.
|
||||
return 1
|
||||
|
||||
async def publish(self, channel: str, payload: str) -> int:
|
||||
self.published.append((channel, payload))
|
||||
return 1
|
||||
|
||||
async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
|
||||
lst = self.lists.get(key)
|
||||
if not lst:
|
||||
return None
|
||||
popped = lst[:count]
|
||||
self.lists[key] = lst[count:]
|
||||
return popped
|
||||
|
||||
async def llen(self, key: str) -> int:
|
||||
return len(self.lists.get(key, []))
|
||||
|
||||
async def lrange(self, key: str, start: int, stop: int) -> list[str | bytes]:
|
||||
lst = self.lists.get(key, [])
|
||||
# Redis LRANGE stop is inclusive; -1 means the last element.
|
||||
if stop == -1:
|
||||
return list(lst[start:])
|
||||
return list(lst[start : stop + 1])
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
if key in self.lists:
|
||||
del self.lists[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def pipeline(self, transaction: bool = True) -> "_FakePipeline":
|
||||
# Returns a fake pipeline that records ops and replays them in
|
||||
# order on ``execute()``. Used by ``capped_rpush`` (push_pending_message)
|
||||
# and ``incr_with_ttl`` (call-rate check) via MULTI/EXEC.
|
||||
return _FakePipeline(self)
|
||||
|
||||
async def incr(self, key: str) -> int:
|
||||
# Used by incr_with_ttl's pipeline.
|
||||
current = int(self.lists.get(key, [0])[0]) if self.lists.get(key) else 0
|
||||
current += 1
|
||||
# We abuse the same lists dict for simple counters — store [count].
|
||||
self.lists[key] = [str(current)]
|
||||
return current
|
||||
|
||||
|
||||
class _FakePipeline:
|
||||
"""Async pipeline shim matching the redis-py MULTI/EXEC surface."""
|
||||
|
||||
def __init__(self, parent: "_FakeRedis") -> None:
|
||||
self._parent = parent
|
||||
self._ops: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = []
|
||||
|
||||
# Each method just records the op; dispatching happens in execute().
|
||||
def rpush(self, key: str, *values: Any) -> "_FakePipeline":
|
||||
self._ops.append(("rpush", (key, *values), {}))
|
||||
return self
|
||||
|
||||
def ltrim(self, key: str, start: int, stop: int) -> "_FakePipeline":
|
||||
self._ops.append(("ltrim", (key, start, stop), {}))
|
||||
return self
|
||||
|
||||
def expire(self, key: str, seconds: int, **kw: Any) -> "_FakePipeline":
|
||||
self._ops.append(("expire", (key, seconds), kw))
|
||||
return self
|
||||
|
||||
def llen(self, key: str) -> "_FakePipeline":
|
||||
self._ops.append(("llen", (key,), {}))
|
||||
return self
|
||||
|
||||
def incr(self, key: str) -> "_FakePipeline":
|
||||
self._ops.append(("incr", (key,), {}))
|
||||
return self
|
||||
|
||||
async def execute(self) -> list[Any]:
|
||||
results: list[Any] = []
|
||||
for name, args, _kw in self._ops:
|
||||
fn = getattr(self._parent, name)
|
||||
results.append(await fn(*args))
|
||||
return results
|
||||
|
||||
# Support `async with pipeline() as pipe:` too.
|
||||
async def __aenter__(self) -> "_FakePipeline":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *a: Any) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis:
|
||||
redis = _FakeRedis()
|
||||
|
||||
async def _get_redis_async() -> _FakeRedis:
|
||||
return redis
|
||||
|
||||
monkeypatch.setattr(pm_module, "get_redis_async", _get_redis_async)
|
||||
return redis
|
||||
|
||||
|
||||
# ── Basic push / drain ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None:
|
||||
length = await push_pending_message("sess1", PendingMessage(content="hello"))
|
||||
assert length == 1
|
||||
assert await peek_pending_count("sess1") == 1
|
||||
|
||||
drained = await drain_pending_messages("sess1")
|
||||
assert len(drained) == 1
|
||||
assert drained[0].content == "hello"
|
||||
assert await peek_pending_count("sess1") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None:
|
||||
for i in range(3):
|
||||
await push_pending_message("sess2", PendingMessage(content=f"msg {i}"))
|
||||
|
||||
drained = await drain_pending_messages("sess2")
|
||||
assert [m.content for m in drained] == ["msg 0", "msg 1", "msg 2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None:
|
||||
assert await drain_pending_messages("nope") == []
|
||||
|
||||
|
||||
# ── Buffer cap ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
|
||||
# Push MAX_PENDING_MESSAGES + 3 messages
|
||||
for i in range(MAX_PENDING_MESSAGES + 3):
|
||||
await push_pending_message("sess3", PendingMessage(content=f"m{i}"))
|
||||
|
||||
# Buffer should be clamped to MAX
|
||||
assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES
|
||||
|
||||
drained = await drain_pending_messages("sess3")
|
||||
assert len(drained) == MAX_PENDING_MESSAGES
|
||||
# Oldest 3 dropped — we should only see m3..m(MAX+2)
|
||||
assert drained[0].content == "m3"
|
||||
assert drained[-1].content == f"m{MAX_PENDING_MESSAGES + 2}"
|
||||
|
||||
|
||||
# ── Clear ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
|
||||
await push_pending_message("sess4", PendingMessage(content="x"))
|
||||
await push_pending_message("sess4", PendingMessage(content="y"))
|
||||
await _clear_pending_messages_unsafe("sess4")
|
||||
assert await peek_pending_count("sess4") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
|
||||
# Clearing an already-empty buffer should not raise
|
||||
await _clear_pending_messages_unsafe("sess_empty")
|
||||
await _clear_pending_messages_unsafe("sess_empty")
|
||||
|
||||
|
||||
# ── Publish hook ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None:
|
||||
await push_pending_message("sess5", PendingMessage(content="hi"))
|
||||
assert ("copilot:pending:notify:sess5", "1") in fake_redis.published
|
||||
|
||||
|
||||
# ── Format helper ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_format_pending_plain_text() -> None:
|
||||
msg = PendingMessage(content="just text")
|
||||
out = format_pending_as_user_message(msg)
|
||||
assert out == {"role": "user", "content": "just text"}
|
||||
|
||||
|
||||
def test_format_pending_with_context_url() -> None:
|
||||
msg = PendingMessage(
|
||||
content="see this page",
|
||||
context=PendingMessageContext(url="https://example.com"),
|
||||
)
|
||||
out = format_pending_as_user_message(msg)
|
||||
content = out["content"]
|
||||
assert out["role"] == "user"
|
||||
assert "see this page" in content
|
||||
# The URL should appear verbatim in the [Page URL: ...] block.
|
||||
assert "[Page URL: https://example.com]" in content
|
||||
|
||||
|
||||
def test_format_pending_with_file_ids() -> None:
|
||||
msg = PendingMessage(content="look here", file_ids=["a", "b"])
|
||||
out = format_pending_as_user_message(msg)
|
||||
assert "file_id=a" in out["content"]
|
||||
assert "file_id=b" in out["content"]
|
||||
|
||||
|
||||
def test_format_pending_with_all_fields() -> None:
|
||||
"""All fields (content + context url/content + file_ids) should all appear."""
|
||||
msg = PendingMessage(
|
||||
content="summarise this",
|
||||
context=PendingMessageContext(
|
||||
url="https://example.com/page",
|
||||
content="headline text",
|
||||
),
|
||||
file_ids=["f1", "f2"],
|
||||
)
|
||||
out = format_pending_as_user_message(msg)
|
||||
body = out["content"]
|
||||
assert out["role"] == "user"
|
||||
assert "summarise this" in body
|
||||
assert "[Page URL: https://example.com/page]" in body
|
||||
assert "[Page content]\nheadline text" in body
|
||||
assert "file_id=f1" in body
|
||||
assert "file_id=f2" in body
|
||||
|
||||
|
||||
# ── Followup block caps ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_format_followup_single_message() -> None:
|
||||
out = format_pending_as_followup([PendingMessage(content="hello")])
|
||||
assert "<user_follow_up>" in out
|
||||
assert "</user_follow_up>" in out
|
||||
assert "Message 1:\nhello" in out
|
||||
|
||||
|
||||
def test_format_followup_total_cap_drops_overflow() -> None:
|
||||
"""10 × 2 KB messages must truncate past the total cap (~6 KB) with a
|
||||
marker indicating how many were dropped."""
|
||||
messages = [PendingMessage(content="A" * 2_000) for _ in range(10)]
|
||||
out = format_pending_as_followup(messages)
|
||||
# Block stays within the total cap (plus a little wrapper overhead).
|
||||
# The body alone is capped at 6 KB; we allow generous overhead for the
|
||||
# <user_follow_up> wrapper + headers.
|
||||
assert len(out) < 8_000
|
||||
assert "more message(s) truncated" in out
|
||||
# The first message at least must be present.
|
||||
assert "Message 1:" in out
|
||||
|
||||
|
||||
def test_format_followup_total_cap_marker_counts_dropped() -> None:
|
||||
"""The marker should name the exact number of dropped messages."""
|
||||
# Each 3 KB message gets capped to 2 KB first; with ~2 KB per entry and a
|
||||
# 6 KB total cap, roughly two entries fit and the rest are dropped.
|
||||
messages = [PendingMessage(content="X" * 3_000) for _ in range(5)]
|
||||
out = format_pending_as_followup(messages)
|
||||
assert "Message 1:" in out
|
||||
assert "Message 2:" in out
|
||||
# Message 3 would push total past 6 KB; marker should report exactly how
|
||||
# many were left out (here: messages 3, 4, 5 → 3 dropped).
|
||||
assert "[3 more message(s) truncated]" in out
|
||||
|
||||
|
||||
def test_format_followup_empty_returns_empty_string() -> None:
|
||||
assert format_pending_as_followup([]) == ""
|
||||
|
||||
|
||||
# ── Malformed payload handling ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_skips_malformed_entries(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
# Seed the fake with a mix of valid and malformed payloads
|
||||
fake_redis.lists["copilot:pending:bad"] = [
|
||||
json.dumps({"content": "valid"}),
|
||||
"{not valid json",
|
||||
json.dumps({"content": "also valid", "file_ids": ["a"]}),
|
||||
]
|
||||
drained = await drain_pending_messages("bad")
|
||||
assert len(drained) == 2
|
||||
assert drained[0].content == "valid"
|
||||
assert drained[1].content == "also valid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_decodes_bytes_payloads(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""Real redis-py returns ``bytes`` when ``decode_responses=False``.
|
||||
|
||||
Seed the fake with bytes values to exercise the ``decode("utf-8")``
|
||||
branch in ``drain_pending_messages`` so a regression there doesn't
|
||||
slip past CI.
|
||||
"""
|
||||
fake_redis.lists["copilot:pending:bytes_sess"] = [
|
||||
json.dumps({"content": "from bytes"}).encode("utf-8"),
|
||||
]
|
||||
drained = await drain_pending_messages("bytes_sess")
|
||||
assert len(drained) == 1
|
||||
assert drained[0].content == "from bytes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_peek_decodes_bytes_payloads(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""``peek_pending_messages`` uses the same ``_decode_redis_item`` helper
|
||||
as the drain path. Seed with bytes to guard against regression.
|
||||
"""
|
||||
fake_redis.lists["copilot:pending:peek_bytes_sess"] = [
|
||||
json.dumps({"content": "peeked from bytes"}).encode("utf-8"),
|
||||
]
|
||||
peeked = await peek_pending_messages("peek_bytes_sess")
|
||||
assert len(peeked) == 1
|
||||
assert peeked[0].content == "peeked from bytes"
|
||||
# peek must NOT consume the item
|
||||
assert fake_redis.lists["copilot:pending:peek_bytes_sess"] != []
|
||||
|
||||
|
||||
# ── Concurrency ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_push_and_drain(fake_redis: _FakeRedis) -> None:
|
||||
"""Two pushes fired concurrently should both land; a concurrent drain
|
||||
should see at least one of them (the fake serialises, so it will
|
||||
always see both, but we exercise the code path either way)."""
|
||||
await asyncio.gather(
|
||||
push_pending_message("sess_conc", PendingMessage(content="a")),
|
||||
push_pending_message("sess_conc", PendingMessage(content="b")),
|
||||
)
|
||||
drained = await drain_pending_messages("sess_conc")
|
||||
assert len(drained) >= 1
|
||||
contents = {m.content for m in drained}
|
||||
assert contents <= {"a", "b"}
|
||||
|
||||
|
||||
# ── Publish error path ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_survives_publish_failure(
|
||||
fake_redis: _FakeRedis, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A publish error must not propagate — the buffer is still authoritative."""
|
||||
|
||||
async def _fail_publish(channel: str, payload: str) -> int:
|
||||
raise RuntimeError("redis publish down")
|
||||
|
||||
monkeypatch.setattr(fake_redis, "publish", _fail_publish)
|
||||
|
||||
length = await push_pending_message("sess_pub_err", PendingMessage(content="ok"))
|
||||
assert length == 1
|
||||
drained = await drain_pending_messages("sess_pub_err")
|
||||
assert len(drained) == 1
|
||||
assert drained[0].content == "ok"
|
||||
|
||||
|
||||
# ── peek_pending_messages ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_peek_pending_messages_returns_all_without_consuming(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""Peek returns all queued messages and leaves the buffer intact."""
|
||||
await push_pending_message("peek1", PendingMessage(content="first"))
|
||||
await push_pending_message("peek1", PendingMessage(content="second"))
|
||||
|
||||
peeked = await peek_pending_messages("peek1")
|
||||
assert len(peeked) == 2
|
||||
assert peeked[0].content == "first"
|
||||
assert peeked[1].content == "second"
|
||||
|
||||
# Buffer must not be consumed — count still 2
|
||||
assert await peek_pending_count("peek1") == 2
|
||||
drained = await drain_pending_messages("peek1")
|
||||
assert len(drained) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_peek_pending_messages_empty_buffer(fake_redis: _FakeRedis) -> None:
|
||||
"""Peek on a missing key returns an empty list without raising."""
|
||||
result = await peek_pending_messages("no_such_session")
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_peek_pending_messages_decodes_bytes_payloads(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""peek_pending_messages decodes bytes entries the same way drain does."""
|
||||
fake_redis.lists["copilot:pending:peek_bytes"] = [
|
||||
json.dumps({"content": "from bytes"}).encode("utf-8"),
|
||||
]
|
||||
peeked = await peek_pending_messages("peek_bytes")
|
||||
assert len(peeked) == 1
|
||||
assert peeked[0].content == "from bytes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_peek_pending_messages_skips_malformed_entries(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""Malformed entries are skipped and valid ones are returned."""
|
||||
fake_redis.lists["copilot:pending:peek_bad"] = [
|
||||
json.dumps({"content": "valid peek"}),
|
||||
"{bad json",
|
||||
json.dumps({"content": "also valid peek"}),
|
||||
]
|
||||
peeked = await peek_pending_messages("peek_bad")
|
||||
assert len(peeked) == 2
|
||||
assert peeked[0].content == "valid peek"
|
||||
assert peeked[1].content == "also valid peek"
|
||||
|
||||
|
||||
# ── Persist queue (mid-turn follow-up UI bubble hand-off) ───────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_for_persist_enqueues_and_drain_pops_in_order(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""stash_pending_for_persist writes messages under the persist key;
|
||||
drain_pending_for_persist LPOPs them in enqueue order."""
|
||||
msgs = [
|
||||
PendingMessage(content="first mid-turn follow-up"),
|
||||
PendingMessage(content="second"),
|
||||
]
|
||||
await stash_pending_for_persist("sess-persist", msgs)
|
||||
|
||||
# Stored under the distinct persist key, NOT the primary buffer.
|
||||
assert "copilot:pending-persist:sess-persist" in fake_redis.lists
|
||||
assert "copilot:pending:sess-persist" not in fake_redis.lists
|
||||
|
||||
drained = await drain_pending_for_persist("sess-persist")
|
||||
assert len(drained) == 2
|
||||
assert drained[0].content == "first mid-turn follow-up"
|
||||
assert drained[1].content == "second"
|
||||
|
||||
# Queue is empty after drain.
|
||||
assert await drain_pending_for_persist("sess-persist") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_for_persist_empty_list_is_noop(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""Passing an empty list must NOT create a Redis key (would leak
|
||||
empty persist entries and require a drain for no reason)."""
|
||||
await stash_pending_for_persist("sess-noop", [])
|
||||
assert "copilot:pending-persist:sess-noop" not in fake_redis.lists
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_pending_for_persist_missing_key_returns_empty(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
assert await drain_pending_for_persist("never-stashed") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_pending_for_persist_skips_malformed(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
fake_redis.lists["copilot:pending-persist:bad"] = [
|
||||
json.dumps({"content": "good one"}),
|
||||
"not json",
|
||||
json.dumps({"content": "another good one"}),
|
||||
]
|
||||
result = await drain_pending_for_persist("bad")
|
||||
assert [m.content for m in result] == ["good one", "another good one"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_queue_isolated_from_primary_buffer(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""Draining the persist queue must NOT touch the primary pending
|
||||
buffer (and vice versa) — they serve different lifecycles."""
|
||||
# Seed the primary buffer with one entry.
|
||||
await push_pending_message("sess-iso", PendingMessage(content="primary"))
|
||||
# Stash a separate entry on the persist queue.
|
||||
await stash_pending_for_persist("sess-iso", [PendingMessage(content="persist")])
|
||||
|
||||
drained_persist = await drain_pending_for_persist("sess-iso")
|
||||
assert [m.content for m in drained_persist] == ["persist"]
|
||||
|
||||
# Primary buffer untouched.
|
||||
assert await peek_pending_count("sess-iso") == 1
|
||||
drained_primary = await drain_pending_messages("sess-iso")
|
||||
assert [m.content for m in drained_primary] == ["primary"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_for_persist_swallows_redis_failure(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A broken Redis during stash must not raise — Claude has already
|
||||
seen the follow-up via tool output; the only fallout is a missing
|
||||
UI bubble, which we log and move on."""
|
||||
|
||||
async def _broken_redis() -> Any:
|
||||
raise ConnectionError("redis down")
|
||||
|
||||
monkeypatch.setattr(pm_module, "get_redis_async", _broken_redis)
|
||||
|
||||
# Must NOT raise.
|
||||
await stash_pending_for_persist("sess-broken", [PendingMessage(content="lost")])
|
||||
|
||||
|
||||
# ── drain_and_format_for_injection: shared entry point ─────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_and_format_for_injection_happy_path(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
"""Queued messages drain into a ready-to-inject <user_follow_up> block
|
||||
AND are stashed on the persist queue for UI row hand-off."""
|
||||
await push_pending_message("sess-share", PendingMessage(content="do X also"))
|
||||
|
||||
result = await drain_and_format_for_injection("sess-share", log_prefix="[TEST]")
|
||||
|
||||
assert "<user_follow_up>" in result
|
||||
assert "do X also" in result
|
||||
# Primary buffer drained.
|
||||
assert await peek_pending_count("sess-share") == 0
|
||||
# Persist queue got a copy for the UI.
|
||||
persisted = await drain_pending_for_persist("sess-share")
|
||||
assert len(persisted) == 1
|
||||
assert persisted[0].content == "do X also"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_and_format_for_injection_empty_returns_empty(
|
||||
fake_redis: _FakeRedis,
|
||||
) -> None:
|
||||
assert await drain_and_format_for_injection("sess-empty", log_prefix="[TEST]") == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_and_format_for_injection_swallows_redis_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _broken() -> Any:
|
||||
raise ConnectionError("down")
|
||||
|
||||
monkeypatch.setattr(pm_module, "get_redis_async", _broken)
|
||||
|
||||
# Must NOT raise — broken Redis becomes "nothing to inject".
|
||||
assert (
|
||||
await drain_and_format_for_injection("sess-broken", log_prefix="[TEST]") == ""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_and_format_for_injection_missing_session_id() -> None:
|
||||
assert await drain_and_format_for_injection("", log_prefix="[TEST]") == ""
|
||||
@@ -87,6 +87,7 @@ ToolName = Literal[
|
||||
"get_agent_building_guide",
|
||||
"get_doc_page",
|
||||
"get_mcp_guide",
|
||||
"get_sub_session_result",
|
||||
"list_folders",
|
||||
"list_workspace_files",
|
||||
"memory_forget_confirm",
|
||||
@@ -99,6 +100,7 @@ ToolName = Literal[
|
||||
"run_agent",
|
||||
"run_block",
|
||||
"run_mcp_tool",
|
||||
"run_sub_session",
|
||||
"search_docs",
|
||||
"search_feature_requests",
|
||||
"update_folder",
|
||||
|
||||
@@ -8,11 +8,12 @@ handling the distinction between:
|
||||
|
||||
from functools import cache
|
||||
|
||||
from backend.blocks.autopilot import AUTOPILOT_BLOCK_ID
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = f"""\
|
||||
# Workflow rules appended to the system prompt on every copilot turn
|
||||
# (baseline appends directly; SDK appends via the storage-supplement
|
||||
# template). These are cross-tool rules (file sharing, @@agptfile: refs,
|
||||
# tool-discovery priority, sub-agent etiquette) that don't belong on any
|
||||
# individual tool schema.
|
||||
SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files
|
||||
After `write_workspace_file`, embed the `download_url` in Markdown:
|
||||
@@ -68,13 +69,13 @@ that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{{
|
||||
"files": [{{
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}}]
|
||||
}}
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Writing large files — CRITICAL (causes production failures)
|
||||
@@ -149,20 +150,27 @@ When the user asks to interact with a service or API, follow this order:
|
||||
All tasks must run in the foreground.
|
||||
|
||||
### Delegating to another autopilot (sub-autopilot pattern)
|
||||
Use the **AutoPilotBlock** (`run_block` with block_id
|
||||
`{AUTOPILOT_BLOCK_ID}`) to delegate a task to a fresh
|
||||
autopilot instance. The sub-autopilot has its own full tool set and can
|
||||
perform multi-step work autonomously.
|
||||
Use the **`run_sub_session`** tool to delegate a task to a fresh
|
||||
sub-AutoPilot. The sub has its own full tool set and can perform
|
||||
multi-step work autonomously.
|
||||
|
||||
- **Input**: `prompt` (required) — the task description.
|
||||
Optional: `system_context` to constrain behavior, `session_id` to
|
||||
continue a previous conversation, `max_recursion_depth` (default 3).
|
||||
- **Output**: `response` (text), `tool_calls` (list), `session_id`
|
||||
(for continuation), `conversation_history`, `token_usage`.
|
||||
- `prompt` (required): the task description.
|
||||
- `system_context` (optional): extra context prepended to the prompt.
|
||||
- `sub_autopilot_session_id` (optional): continue an existing
|
||||
sub-AutoPilot — pass the `sub_autopilot_session_id` returned by a
|
||||
previous completed run.
|
||||
- `wait_for_result` (default 60, max 300): seconds to wait inline. If
|
||||
the sub isn't done by then you get `status="running"` + a
|
||||
`sub_session_id` — call **`get_sub_session_result`** with that id
|
||||
(wait up to 300s more per call) until it returns `completed` or
|
||||
`error`. Works across turns — safe to reconnect in a later message.
|
||||
|
||||
Use this when a task is complex enough to benefit from a separate
|
||||
autopilot context, e.g. "research X and write a report" while the
|
||||
parent autopilot handles orchestration.
|
||||
parent autopilot handles orchestration. Do NOT invoke `AutoPilotBlock`
|
||||
via `run_block` — it's hidden from `run_block` by design because the
|
||||
dedicated tool handles the async lifecycle correctly.
|
||||
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
@@ -174,6 +182,7 @@ sandbox so `bash_exec` can access it for further processing.
|
||||
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
|
||||
|
||||
### GitHub CLI (`gh`) and git
|
||||
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
@@ -254,7 +263,7 @@ When a tool output contains `<tool-output-truncated workspace_path="...">`, the
|
||||
full output is in workspace storage (NOT on the local filesystem). To access it:
|
||||
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
|
||||
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
{SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
@@ -305,33 +314,37 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
_USER_FOLLOW_UP_NOTE = """
|
||||
# `<user_follow_up>` blocks in tool output
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
SDK mode doesn't need it since Claude gets tool schemas automatically.
|
||||
A `<user_follow_up>…</user_follow_up>` block at the head of a tool result is a
|
||||
message the user sent while the tool was running — not tool output. The user is
|
||||
watching the chat live and waiting for confirmation their message landed.
|
||||
|
||||
This generates a complete list of available tools with their descriptions,
|
||||
ensuring the documentation stays in sync with the actual tool implementations.
|
||||
All workflow guidance is now embedded in individual tool descriptions.
|
||||
Every time you see one:
|
||||
|
||||
Only documents tools that are available in the current environment
|
||||
(checked via tool.is_available property).
|
||||
"""
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
1. **Ack immediately.** Your very next emission must be a short visible line,
|
||||
before any more tool calls:
|
||||
*"Got your follow-up: {paraphrase}. {what I'll do}."*
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
docs += f"- **`{name}`**: {desc}\n"
|
||||
2. **Then act on it:**
|
||||
- Question/input request → stop the tool chain and answer/ask back.
|
||||
- New requirement → fold into the current plan.
|
||||
- Correction → update the plan and continue with the revised target.
|
||||
|
||||
return docs
|
||||
Never echo the `<user_follow_up>` tags back. The block holds only the user's
|
||||
words — the rest of the tool result is the real data.
|
||||
|
||||
# Always close the turn with visible text
|
||||
|
||||
Every turn MUST end with at least one short user-facing text sentence —
|
||||
even if it is only "Done." or "I'm stopping here because X." Never end a
|
||||
turn with only tool calls or only thinking. The user's UI renders text
|
||||
messages; a turn that emits only thinking blocks or only tool calls shows
|
||||
up as a frozen screen with no response. If your plan was to stop after
|
||||
the last tool result, still produce one closing sentence summarising
|
||||
what happened so the user knows the turn is complete.
|
||||
"""
|
||||
|
||||
|
||||
@cache
|
||||
@@ -356,9 +369,12 @@ def get_sdk_supplement(use_e2b: bool) -> str:
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
if use_e2b:
|
||||
return _get_cloud_sandbox_supplement()
|
||||
return _get_local_storage_supplement("/tmp/copilot-<session-id>")
|
||||
base = (
|
||||
_get_cloud_sandbox_supplement()
|
||||
if use_e2b
|
||||
else _get_local_storage_supplement("/tmp/copilot-<session-id>")
|
||||
)
|
||||
return base + _USER_FOLLOW_UP_NOTE
|
||||
|
||||
|
||||
def get_graphiti_supplement() -> str:
|
||||
@@ -395,17 +411,3 @@ You have access to persistent temporal memory tools that remember facts across s
|
||||
- group_id is handled automatically by the system — never set it yourself.
|
||||
- When storing, be specific about operational rules and instructions (e.g., "CC Sarah on client communications" not just "Sarah is the assistant").
|
||||
"""
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
direct API doesn't automatically provide tool schemas to Claude.
|
||||
Also includes shared technical notes (but NOT SDK-specific environment details).
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
"""CoPilot rate limiting based on token usage.
|
||||
"""CoPilot rate limiting based on generation cost.
|
||||
|
||||
Uses Redis fixed-window counters to track per-user token consumption
|
||||
with configurable daily and weekly limits. Daily windows reset at
|
||||
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
Uses Redis fixed-window counters to track per-user USD spend (stored as
|
||||
microdollars, matching ``PlatformCostLog.cost_microdollars``) with
|
||||
configurable daily and weekly limits. Daily windows reset at midnight UTC;
|
||||
weekly windows reset at ISO week boundary (Monday 00:00 UTC). Fails open
|
||||
when Redis is unavailable to avoid blocking users.
|
||||
|
||||
Storing microdollars rather than tokens means the counter already reflects
|
||||
real model pricing (including cache discounts and provider surcharges), so
|
||||
this module carries no pricing table — the cost comes from OpenRouter's
|
||||
``usage.cost`` field (baseline) or the Claude Agent SDK's reported total
|
||||
cost (SDK path).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -17,12 +24,15 @@ from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.db_accessors import user_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.util.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key prefixes
|
||||
_USAGE_KEY_PREFIX = "copilot:usage"
|
||||
# Redis key prefixes. Bumped from "copilot:usage" (token-based) to
|
||||
# "copilot:cost" on the token→cost migration so stale counters do not
|
||||
# get misinterpreted as microdollars (which would dramatically under-count).
|
||||
_USAGE_KEY_PREFIX = "copilot:cost"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -31,7 +41,7 @@ _USAGE_KEY_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
class SubscriptionTier(str, Enum):
|
||||
"""Subscription tiers with increasing token allowances.
|
||||
"""Subscription tiers with increasing cost allowances.
|
||||
|
||||
Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
|
||||
Once ``prisma generate`` is run, this can be replaced with::
|
||||
@@ -45,9 +55,9 @@ class SubscriptionTier(str, Enum):
|
||||
ENTERPRISE = "ENTERPRISE"
|
||||
|
||||
|
||||
# Multiplier applied to the base limits (from LD / config) for each tier.
|
||||
# Intentionally int (not float): keeps limits as whole token counts and avoids
|
||||
# floating-point rounding. If fractional multipliers are ever needed, change
|
||||
# Multiplier applied to the base cost limits (from LD / config) for each tier.
|
||||
# Intentionally int (not float): keeps limits as whole microdollars and avoids
|
||||
# floating-point rounding. If fractional multipliers are ever needed, change
|
||||
# the type and round the result in get_global_rate_limits().
|
||||
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
|
||||
SubscriptionTier.FREE: 1,
|
||||
@@ -60,17 +70,27 @@ DEFAULT_TIER = SubscriptionTier.FREE
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
"""Usage within a single time window.
|
||||
|
||||
``used`` and ``limit`` are in microdollars (1 USD = 1_000_000).
|
||||
"""
|
||||
|
||||
used: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
description="Maximum microdollars of spend allowed in this window. "
|
||||
"0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsageStatus(BaseModel):
|
||||
"""Current usage status for a user across all windows."""
|
||||
"""Current usage status for a user across all windows.
|
||||
|
||||
Internal representation used by server-side code that needs to compare
|
||||
usage against limits (e.g. the reset-credits endpoint). The public API
|
||||
returns ``CoPilotUsagePublic`` instead so that raw spend and limit
|
||||
figures never leak to clients.
|
||||
"""
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
@@ -81,6 +101,68 @@ class CoPilotUsageStatus(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class UsageWindowPublic(BaseModel):
|
||||
"""Public view of a usage window — only the percentage and reset time.
|
||||
|
||||
Hides the raw spend and the cap so clients cannot derive per-turn cost
|
||||
or reverse-engineer platform margins. ``percent_used`` is capped at 100.
|
||||
"""
|
||||
|
||||
percent_used: float = Field(
|
||||
ge=0.0,
|
||||
le=100.0,
|
||||
description="Percentage of the window's allowance used (0-100). "
|
||||
"Clamped at 100 when over the cap.",
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsagePublic(BaseModel):
|
||||
"""Current usage status for a user — public (client-safe) shape."""
|
||||
|
||||
daily: UsageWindowPublic | None = Field(
|
||||
default=None,
|
||||
description="Null when no daily cap is configured (unlimited).",
|
||||
)
|
||||
weekly: UsageWindowPublic | None = Field(
|
||||
default=None,
|
||||
description="Null when no weekly cap is configured (unlimited).",
|
||||
)
|
||||
tier: SubscriptionTier = DEFAULT_TIER
|
||||
reset_cost: int = Field(
|
||||
default=0,
|
||||
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_status(cls, status: CoPilotUsageStatus) -> "CoPilotUsagePublic":
|
||||
"""Project the internal status onto the client-safe schema."""
|
||||
|
||||
def window(w: UsageWindow) -> UsageWindowPublic | None:
|
||||
if w.limit <= 0:
|
||||
return None
|
||||
# When at/over the cap, snap to exactly 100.0 so the UI's
|
||||
# rounded display and its exhaustion check (`percent_used >= 100`)
|
||||
# agree. Without this, e.g. 99.95% would render as "100% used"
|
||||
# via Math.round but fail the exhaustion check, leaving the
|
||||
# reset button hidden while the bar appears full.
|
||||
if w.used >= w.limit:
|
||||
pct = 100.0
|
||||
else:
|
||||
pct = round(100.0 * w.used / w.limit, 1)
|
||||
return UsageWindowPublic(
|
||||
percent_used=pct,
|
||||
resets_at=w.resets_at,
|
||||
)
|
||||
|
||||
return cls(
|
||||
daily=window(status.daily),
|
||||
weekly=window(status.weekly),
|
||||
tier=status.tier,
|
||||
reset_cost=status.reset_cost,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||
|
||||
@@ -102,8 +184,8 @@ class RateLimitExceeded(Exception):
|
||||
|
||||
async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
daily_cost_limit: int,
|
||||
weekly_cost_limit: int,
|
||||
rate_limit_reset_cost: int = 0,
|
||||
tier: SubscriptionTier = DEFAULT_TIER,
|
||||
) -> CoPilotUsageStatus:
|
||||
@@ -111,13 +193,13 @@ async def get_usage_status(
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
daily_cost_limit: Max microdollars of spend per day (0 = unlimited).
|
||||
weekly_cost_limit: Max microdollars of spend per week (0 = unlimited).
|
||||
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
|
||||
tier: The user's rate-limit tier (included in the response).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
CoPilotUsageStatus with current usage and limits in microdollars.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
daily_used = 0
|
||||
@@ -136,12 +218,12 @@ async def get_usage_status(
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_token_limit,
|
||||
limit=daily_cost_limit,
|
||||
resets_at=_daily_reset_time(now=now),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
limit=weekly_token_limit,
|
||||
limit=weekly_cost_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
tier=tier,
|
||||
@@ -151,22 +233,22 @@ async def get_usage_status(
|
||||
|
||||
async def check_rate_limit(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
daily_cost_limit: int,
|
||||
weekly_cost_limit: int,
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
by ``record_cost_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
This is acceptable because cost-based limits are approximate by nature
|
||||
(the exact cost is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
|
||||
# round-trip entirely.
|
||||
if daily_token_limit <= 0 and weekly_token_limit <= 0:
|
||||
if daily_cost_limit <= 0 and weekly_cost_limit <= 0:
|
||||
return
|
||||
|
||||
now = datetime.now(UTC)
|
||||
@@ -182,26 +264,25 @@ async def check_rate_limit(
|
||||
logger.warning("Redis unavailable for rate limit check, allowing request")
|
||||
return
|
||||
|
||||
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
|
||||
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||
if daily_cost_limit > 0 and daily_used >= daily_cost_limit:
|
||||
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
if weekly_cost_limit > 0 and weekly_used >= weekly_cost_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
|
||||
"""Reset a user's daily token usage counter in Redis.
|
||||
async def reset_daily_usage(user_id: str, daily_cost_limit: int = 0) -> bool:
|
||||
"""Reset a user's daily cost usage counter in Redis.
|
||||
|
||||
Called after a user pays credits to extend their daily limit.
|
||||
Also reduces the weekly usage counter by ``daily_token_limit`` tokens
|
||||
Also reduces the weekly usage counter by ``daily_cost_limit`` microdollars
|
||||
(clamped to 0) so the user effectively gets one extra day's worth of
|
||||
weekly capacity.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: The configured daily token limit. When positive,
|
||||
the weekly counter is reduced by this amount.
|
||||
daily_cost_limit: The configured daily cost limit in microdollars.
|
||||
When positive, the weekly counter is reduced by this amount.
|
||||
|
||||
Returns False if Redis is unavailable so the caller can handle
|
||||
compensation (fail-closed for billed operations, unlike the read-only
|
||||
@@ -217,12 +298,12 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
|
||||
# counter is not decremented — which would let the caller refund
|
||||
# credits even though the daily limit was already reset.
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None
|
||||
w_key = _weekly_key(user_id, now=now) if daily_cost_limit > 0 else None
|
||||
|
||||
pipe = redis.pipeline(transaction=True)
|
||||
pipe.delete(d_key)
|
||||
if w_key is not None:
|
||||
pipe.decrby(w_key, daily_token_limit)
|
||||
pipe.decrby(w_key, daily_cost_limit)
|
||||
results = await pipe.execute()
|
||||
|
||||
# Clamp negative weekly counter to 0 (best-effort; not critical).
|
||||
@@ -295,84 +376,40 @@ async def increment_daily_reset_count(user_id: str) -> None:
|
||||
logger.warning("Redis unavailable for tracking reset count")
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
async def record_cost_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
model_cost_multiplier: float = 1.0,
|
||||
cost_microdollars: int,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
"""Record a user's generation spend against daily and weekly counters.
|
||||
|
||||
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||
multi-turn conversations. Anthropic's pricing:
|
||||
- uncached input: 100%
|
||||
- cache creation: 25%
|
||||
- cache read: 10%
|
||||
- output: 100%
|
||||
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
``model_cost_multiplier`` scales the final weighted total to reflect
|
||||
relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet)
|
||||
so that Opus turns deplete the rate limit faster, proportional to cost.
|
||||
``cost_microdollars`` is the real generation cost reported by the
|
||||
provider (OpenRouter's ``usage.cost`` or the Claude Agent SDK's
|
||||
``total_cost_usd`` converted to microdollars). Because the provider
|
||||
cost already reflects model pricing and cache discounts, this function
|
||||
carries no pricing table or weighting — it just increments counters.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus).
|
||||
cost_microdollars: Spend to record in microdollars (1 USD = 1_000_000).
|
||||
Non-positive values are ignored.
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
weighted_input = (
|
||||
prompt_tokens
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = round(
|
||||
(weighted_input + completion_tokens) * max(1.0, model_cost_multiplier)
|
||||
)
|
||||
if total <= 0:
|
||||
cost_microdollars = max(0, cost_microdollars)
|
||||
if cost_microdollars <= 0:
|
||||
return
|
||||
|
||||
raw_total = (
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
model_cost_multiplier,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
completion_tokens,
|
||||
)
|
||||
logger.info("Recording copilot spend: %d microdollars", cost_microdollars)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
# transaction=False: these are independent INCRBY+EXPIRE pairs on
|
||||
# separate keys — no cross-key atomicity needed. Skipping
|
||||
# MULTI/EXEC avoids the overhead. If the connection drops between
|
||||
# INCRBY and EXPIRE the key survives until the next date-based key
|
||||
# rotation (daily/weekly), so the memory-leak risk is negligible.
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
# Use MULTI/EXEC so each INCRBY/EXPIRE pair is atomic — guarantees
|
||||
# the TTL is set even if the connection drops mid-pipeline, so
|
||||
# counters can never survive past their date-based rotation window.
|
||||
pipe = redis.pipeline(transaction=True)
|
||||
|
||||
# Daily counter (expires at next midnight UTC)
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
pipe.incrby(d_key, total)
|
||||
pipe.incrby(d_key, cost_microdollars)
|
||||
seconds_until_daily_reset = int(
|
||||
(_daily_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
@@ -380,7 +417,7 @@ async def record_token_usage(
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id, now=now)
|
||||
pipe.incrby(w_key, total)
|
||||
pipe.incrby(w_key, cost_microdollars)
|
||||
seconds_until_weekly_reset = int(
|
||||
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
@@ -389,8 +426,8 @@ async def record_token_usage(
|
||||
await pipe.execute()
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"Redis unavailable for recording token usage (tokens=%d)",
|
||||
total,
|
||||
"Redis unavailable for recording cost usage (microdollars=%d)",
|
||||
cost_microdollars,
|
||||
)
|
||||
|
||||
|
||||
@@ -459,8 +496,20 @@ get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-
|
||||
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
"""Persist the user's rate-limit tier to the database.
|
||||
|
||||
Also invalidates the ``get_user_tier`` cache for this user so that
|
||||
subsequent rate-limit checks immediately see the new tier.
|
||||
Invalidates every cache that keys off the user's subscription tier so the
|
||||
change is visible immediately: this function's own ``get_user_tier``, the
|
||||
shared ``get_user_by_id`` (which exposes ``user.subscription_tier``), and
|
||||
``get_pending_subscription_change`` (since an admin override can invalidate
|
||||
a cached ``cancel_at_period_end`` or schedule-based pending state).
|
||||
|
||||
If the user has an active Stripe subscription whose current price does not
|
||||
match ``tier``, Stripe will keep billing the old price and the next
|
||||
``customer.subscription.updated`` webhook will overwrite the DB tier back
|
||||
to whatever Stripe has. Proper reconciliation (cancelling or modifying the
|
||||
Stripe subscription when an admin overrides the tier) is out of scope for
|
||||
this PR — it changes the admin contract and needs its own test coverage.
|
||||
For now we emit a ``WARNING`` so drift surfaces via Sentry until that
|
||||
follow-up lands.
|
||||
|
||||
Raises:
|
||||
prisma.errors.RecordNotFoundError: If the user does not exist.
|
||||
@@ -469,8 +518,113 @@ async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
where={"id": user_id},
|
||||
data={"subscriptionTier": tier.value},
|
||||
)
|
||||
# Invalidate cached tier so rate-limit checks pick up the change immediately.
|
||||
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
# Local import required: backend.data.credit imports backend.copilot.rate_limit
|
||||
# (via get_user_tier in credit.py's _invalidate_user_tier_caches), so a
|
||||
# top-level ``from backend.data.credit import ...`` here would create a
|
||||
# circular import at module-load time.
|
||||
from backend.data.credit import get_pending_subscription_change
|
||||
|
||||
get_user_by_id.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
get_pending_subscription_change.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
|
||||
# The DB write above is already committed; the drift check is best-effort
|
||||
# diagnostic logging. Fire-and-forget so admin bulk ops don't wait on a
|
||||
# Stripe roundtrip. The inner helper wraps its body in a timeout + broad
|
||||
# except so background task errors still surface via logs rather than as
|
||||
# "task exception never retrieved" warnings. Cancellation on request
|
||||
# shutdown is acceptable — the drift warning is non-load-bearing.
|
||||
asyncio.ensure_future(_drift_check_background(user_id, tier))
|
||||
|
||||
|
||||
async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None:
|
||||
"""Run the Stripe drift check in the background, logging rather than raising."""
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_warn_if_stripe_subscription_drifts(user_id, tier),
|
||||
timeout=5.0,
|
||||
)
|
||||
logger.debug(
|
||||
"set_user_tier: drift check completed for user=%s admin_tier=%s",
|
||||
user_id,
|
||||
tier.value,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"set_user_tier: drift check timed out for user=%s admin_tier=%s",
|
||||
user_id,
|
||||
tier.value,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
# Request may have completed and the event loop is cancelling tasks —
|
||||
# the drift log is non-critical, so accept cancellation silently.
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"set_user_tier: drift check background task failed for"
|
||||
" user=%s admin_tier=%s",
|
||||
user_id,
|
||||
tier.value,
|
||||
)
|
||||
|
||||
|
||||
async def _warn_if_stripe_subscription_drifts(
|
||||
user_id: str, new_tier: SubscriptionTier
|
||||
) -> None:
|
||||
"""Emit a WARNING when an admin tier override leaves an active Stripe sub on a
|
||||
mismatched price.
|
||||
|
||||
The warning is diagnostic only: Stripe remains the billing source of truth,
|
||||
so the next ``customer.subscription.updated`` webhook will reset the DB
|
||||
tier. Surfacing the drift here lets ops catch admin overrides that bypass
|
||||
the intended Checkout / Portal cancel flows before users notice surprise
|
||||
charges.
|
||||
"""
|
||||
# Local imports: see note in ``set_user_tier`` about the credit <-> rate_limit
|
||||
# circular. These helpers (``_get_active_subscription``,
|
||||
# ``get_subscription_price_id``) live in credit.py alongside the rest of
|
||||
# the Stripe billing code.
|
||||
from backend.data.credit import _get_active_subscription, get_subscription_price_id
|
||||
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
if not getattr(user, "stripe_customer_id", None):
|
||||
return
|
||||
sub = await _get_active_subscription(user.stripe_customer_id)
|
||||
if sub is None:
|
||||
return
|
||||
items = sub["items"].data
|
||||
if not items:
|
||||
return
|
||||
price = items[0].price
|
||||
current_price_id = price if isinstance(price, str) else price.id
|
||||
# The LaunchDarkly-backed price lookup must live inside this try/except:
|
||||
# an LD SDK failure (network, token revoked) here would otherwise
|
||||
# propagate past set_user_tier's already-committed DB write and turn a
|
||||
# best-effort diagnostic into a 500 on admin tier writes.
|
||||
expected_price_id = await get_subscription_price_id(new_tier)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"_warn_if_stripe_subscription_drifts: drift lookup failed for"
|
||||
" user=%s; skipping drift warning",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
if expected_price_id is not None and expected_price_id == current_price_id:
|
||||
return
|
||||
logger.warning(
|
||||
"Admin tier override will drift from Stripe: user=%s admin_tier=%s"
|
||||
" stripe_sub=%s stripe_price=%s expected_price=%s — the next"
|
||||
" customer.subscription.updated webhook will reconcile the DB tier"
|
||||
" back to whatever Stripe has; cancel or modify the Stripe subscription"
|
||||
" if you intended the admin override to stick.",
|
||||
user_id,
|
||||
new_tier.value,
|
||||
sub.id,
|
||||
current_price_id,
|
||||
expected_price_id,
|
||||
)
|
||||
|
||||
|
||||
async def get_global_rate_limits(
|
||||
@@ -480,37 +634,41 @@ async def get_global_rate_limits(
|
||||
) -> tuple[int, int, SubscriptionTier]:
|
||||
"""Resolve global rate limits from LaunchDarkly, falling back to config.
|
||||
|
||||
The base limits (from LD or config) are multiplied by the user's
|
||||
tier multiplier so that higher tiers receive proportionally larger
|
||||
allowances.
|
||||
Values are microdollars. The base limits (from LD or config) are
|
||||
multiplied by the user's tier multiplier so that higher tiers receive
|
||||
proportionally larger allowances.
|
||||
|
||||
Args:
|
||||
user_id: User ID for LD flag evaluation context.
|
||||
config_daily: Fallback daily limit from ChatConfig.
|
||||
config_weekly: Fallback weekly limit from ChatConfig.
|
||||
config_daily: Fallback daily cost limit (microdollars) from ChatConfig.
|
||||
config_weekly: Fallback weekly cost limit (microdollars) from ChatConfig.
|
||||
|
||||
Returns:
|
||||
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
|
||||
(daily_cost_limit, weekly_cost_limit, tier) — limits in microdollars.
|
||||
"""
|
||||
# Lazy import to avoid circular dependency:
|
||||
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value
|
||||
|
||||
daily_raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily
|
||||
)
|
||||
weekly_raw = await get_feature_flag_value(
|
||||
Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly
|
||||
# Fetch daily + weekly flags in parallel — each LD evaluation is an
|
||||
# independent network round-trip, so gather cuts latency roughly in half.
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
get_feature_flag_value(
|
||||
Flag.COPILOT_DAILY_COST_LIMIT.value, user_id, config_daily
|
||||
),
|
||||
get_feature_flag_value(
|
||||
Flag.COPILOT_WEEKLY_COST_LIMIT.value, user_id, config_weekly
|
||||
),
|
||||
)
|
||||
try:
|
||||
daily = max(0, int(daily_raw))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Invalid LD value for daily token limit: %r", daily_raw)
|
||||
logger.warning("Invalid LD value for daily cost limit: %r", daily_raw)
|
||||
daily = config_daily
|
||||
try:
|
||||
weekly = max(0, int(weekly_raw))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
|
||||
logger.warning("Invalid LD value for weekly cost limit: %r", weekly_raw)
|
||||
weekly = config_weekly
|
||||
|
||||
# Apply tier multiplier
|
||||
|
||||
@@ -24,7 +24,7 @@ from .rate_limit import (
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
increment_daily_reset_count,
|
||||
record_token_usage,
|
||||
record_cost_usage,
|
||||
release_reset_lock,
|
||||
reset_daily_usage,
|
||||
reset_user_usage,
|
||||
@@ -82,7 +82,7 @@ class TestGetUsageStatus:
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
|
||||
assert isinstance(status, CoPilotUsageStatus)
|
||||
@@ -98,7 +98,7 @@ class TestGetUsageStatus:
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
@@ -115,7 +115,7 @@ class TestGetUsageStatus:
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
@@ -132,7 +132,7 @@ class TestGetUsageStatus:
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 500
|
||||
@@ -148,7 +148,7 @@ class TestGetUsageStatus:
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
@@ -174,7 +174,7 @@ class TestCheckRateLimit:
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -188,7 +188,7 @@ class TestCheckRateLimit:
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
@@ -203,7 +203,7 @@ class TestCheckRateLimit:
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "weekly"
|
||||
|
||||
@@ -216,7 +216,7 @@ class TestCheckRateLimit:
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -229,15 +229,15 @@ class TestCheckRateLimit:
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — limits of 0 mean unlimited
|
||||
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||
await check_rate_limit(_USER, daily_cost_limit=0, weekly_cost_limit=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# record_token_usage
|
||||
# record_cost_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
class TestRecordCostUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock() -> MagicMock:
|
||||
"""Create a pipeline mock with sync methods and async execute."""
|
||||
@@ -255,27 +255,40 @@ class TestRecordTokenUsage:
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
await record_cost_usage(_USER, cost_microdollars=123_456)
|
||||
|
||||
# Should call incrby twice (daily + weekly) with total=150
|
||||
# Should call incrby twice (daily + weekly) with the same cost
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 150 # daily
|
||||
assert incrby_calls[1].args[1] == 150 # weekly
|
||||
assert incrby_calls[0].args[1] == 123_456 # daily
|
||||
assert incrby_calls[1].args[1] == 123_456 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_zero_tokens(self):
|
||||
async def test_skips_when_cost_is_zero(self):
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||
await record_cost_usage(_USER, cost_microdollars=0)
|
||||
|
||||
# Should not call pipeline at all
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_cost_is_negative(self):
|
||||
"""Negative costs are clamped to zero and skip the pipeline."""
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_cost_usage(_USER, cost_microdollars=-10)
|
||||
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_expire_on_both_keys(self):
|
||||
"""Pipeline should call expire for both daily and weekly keys."""
|
||||
@@ -287,7 +300,7 @@ class TestRecordTokenUsage:
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
await record_cost_usage(_USER, cost_microdollars=5_000)
|
||||
|
||||
expire_calls = mock_pipe.expire.call_args_list
|
||||
assert len(expire_calls) == 2
|
||||
@@ -308,32 +321,7 @@ class TestRecordTokenUsage:
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_weighted_counting(self):
|
||||
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(
|
||||
_USER,
|
||||
prompt_tokens=100, # uncached → 100
|
||||
completion_tokens=50, # output → 50
|
||||
cache_read_tokens=10000, # 10% → 1000
|
||||
cache_creation_tokens=400, # 25% → 100
|
||||
)
|
||||
|
||||
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 1250 # daily
|
||||
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||
await record_cost_usage(_USER, cost_microdollars=5_000)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||
@@ -348,7 +336,7 @@ class TestRecordTokenUsage:
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
await record_cost_usage(_USER, cost_microdollars=5_000)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -581,6 +569,80 @@ class TestSetUserTier:
|
||||
|
||||
assert tier_after == SubscriptionTier.ENTERPRISE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drift_check_swallows_launchdarkly_failure(self):
|
||||
"""LaunchDarkly price-id lookup failures inside the drift check must
|
||||
never bubble up and 500 the admin tier write — the DB update is
|
||||
already committed by the time we check drift."""
|
||||
mock_prisma = AsyncMock()
|
||||
mock_prisma.update = AsyncMock(return_value=None)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.stripe_customer_id = "cus_abc"
|
||||
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.id = "sub_abc"
|
||||
mock_sub["items"].data = [MagicMock(price=MagicMock(id="price_mismatch"))]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.rate_limit.PrismaUser.prisma",
|
||||
return_value=mock_prisma,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit._get_active_subscription",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_sub,
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.get_subscription_price_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LD SDK not initialized"),
|
||||
),
|
||||
):
|
||||
# Must NOT raise — drift check is best-effort diagnostic only.
|
||||
await set_user_tier(_USER, SubscriptionTier.PRO)
|
||||
|
||||
mock_prisma.update.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drift_check_timeout_is_bounded(self):
|
||||
"""A Stripe call that stalls on the 80s SDK default must not block the
|
||||
admin tier write — set_user_tier wraps the drift check in a 5s timeout
|
||||
and logs + returns on TimeoutError."""
|
||||
import asyncio as _asyncio
|
||||
|
||||
mock_prisma = AsyncMock()
|
||||
mock_prisma.update = AsyncMock(return_value=None)
|
||||
|
||||
async def _never_returns(_user_id: str, _tier):
|
||||
await _asyncio.sleep(60)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.rate_limit.PrismaUser.prisma",
|
||||
return_value=mock_prisma,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit._warn_if_stripe_subscription_drifts",
|
||||
side_effect=_never_returns,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.rate_limit.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
await set_user_tier(_USER, SubscriptionTier.PRO)
|
||||
|
||||
# Set_user_tier still completed — the drift timeout did not propagate.
|
||||
mock_prisma.update.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_global_rate_limits with tiers
|
||||
@@ -745,7 +807,7 @@ class TestTierLimitsRespected:
|
||||
assert tier == SubscriptionTier.PRO
|
||||
# Should NOT raise — 3M < 12.5M
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=daily, weekly_token_limit=weekly
|
||||
_USER, daily_cost_limit=daily, weekly_cost_limit=weekly
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -779,7 +841,7 @@ class TestTierLimitsRespected:
|
||||
# Should raise — 2.5M >= 2.5M
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=daily, weekly_token_limit=weekly
|
||||
_USER, daily_cost_limit=daily, weekly_cost_limit=weekly
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -811,7 +873,7 @@ class TestTierLimitsRespected:
|
||||
assert tier == SubscriptionTier.ENTERPRISE
|
||||
# Should NOT raise — 100M < 150M
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=daily, weekly_token_limit=weekly
|
||||
_USER, daily_cost_limit=daily, weekly_cost_limit=weekly
|
||||
)
|
||||
|
||||
|
||||
@@ -838,7 +900,7 @@ class TestResetDailyUsage:
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
result = await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
result = await reset_daily_usage(_USER, daily_cost_limit=10000)
|
||||
|
||||
assert result is True
|
||||
mock_pipe.delete.assert_called_once()
|
||||
@@ -854,7 +916,7 @@ class TestResetDailyUsage:
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
await reset_daily_usage(_USER, daily_cost_limit=10000)
|
||||
|
||||
mock_pipe.decrby.assert_called_once()
|
||||
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed
|
||||
@@ -870,14 +932,14 @@ class TestResetDailyUsage:
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
await reset_daily_usage(_USER, daily_cost_limit=10000)
|
||||
|
||||
mock_pipe.decrby.assert_called_once()
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_weekly_reduction_when_daily_limit_zero(self):
|
||||
"""When daily_token_limit is 0, weekly counter should not be touched."""
|
||||
"""When daily_cost_limit is 0, weekly counter should not be touched."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
|
||||
mock_redis = AsyncMock()
|
||||
@@ -887,7 +949,7 @@ class TestResetDailyUsage:
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await reset_daily_usage(_USER, daily_token_limit=0)
|
||||
await reset_daily_usage(_USER, daily_cost_limit=0)
|
||||
|
||||
mock_pipe.delete.assert_called_once()
|
||||
mock_pipe.decrby.assert_not_called()
|
||||
@@ -898,7 +960,7 @@ class TestResetDailyUsage:
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
result = await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
result = await reset_daily_usage(_USER, daily_cost_limit=10000)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@@ -16,14 +16,14 @@ from backend.util.exceptions import InsufficientBalanceError
|
||||
# Minimal config mock matching ChatConfig fields used by the endpoint.
|
||||
def _make_config(
|
||||
rate_limit_reset_cost: int = 500,
|
||||
daily_token_limit: int = 2_500_000,
|
||||
weekly_token_limit: int = 12_500_000,
|
||||
daily_cost_limit_microdollars: int = 10_000_000,
|
||||
weekly_cost_limit_microdollars: int = 50_000_000,
|
||||
max_daily_resets: int = 5,
|
||||
):
|
||||
cfg = MagicMock()
|
||||
cfg.rate_limit_reset_cost = rate_limit_reset_cost
|
||||
cfg.daily_token_limit = daily_token_limit
|
||||
cfg.weekly_token_limit = weekly_token_limit
|
||||
cfg.daily_cost_limit_microdollars = daily_cost_limit_microdollars
|
||||
cfg.weekly_cost_limit_microdollars = weekly_cost_limit_microdollars
|
||||
cfg.max_daily_resets = max_daily_resets
|
||||
return cfg
|
||||
|
||||
@@ -77,10 +77,10 @@ class TestResetCopilotUsage:
|
||||
assert "not available" in exc_info.value.detail
|
||||
|
||||
async def test_no_daily_limit_returns_400(self):
|
||||
"""When daily_token_limit=0 (unlimited), endpoint returns 400."""
|
||||
"""When daily_cost_limit=0 (unlimited), endpoint returns 400."""
|
||||
|
||||
with (
|
||||
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
|
||||
patch(f"{_MODULE}.config", _make_config(daily_cost_limit_microdollars=0)),
|
||||
patch(f"{_MODULE}.settings", _mock_settings()),
|
||||
_mock_rate_limits(daily=0),
|
||||
):
|
||||
|
||||
@@ -34,6 +34,15 @@ class ResponseType(str, Enum):
|
||||
TEXT_DELTA = "text-delta"
|
||||
TEXT_END = "text-end"
|
||||
|
||||
# Reasoning streaming (extended_thinking content blocks). Matches
|
||||
# the Vercel AI SDK v5 wire names so the client's ``useChat``
|
||||
# transport accumulates these into a ``type: 'reasoning'`` UIMessage
|
||||
# part that the ``ReasoningCollapse`` component renders collapsed by
|
||||
# default.
|
||||
REASONING_START = "reasoning-start"
|
||||
REASONING_DELTA = "reasoning-delta"
|
||||
REASONING_END = "reasoning-end"
|
||||
|
||||
# Tool interaction
|
||||
TOOL_INPUT_START = "tool-input-start"
|
||||
TOOL_INPUT_AVAILABLE = "tool-input-available"
|
||||
@@ -130,6 +139,31 @@ class StreamTextEnd(StreamBaseResponse):
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
# ========== Reasoning Streaming ==========
|
||||
|
||||
|
||||
class StreamReasoningStart(StreamBaseResponse):
|
||||
"""Start of a reasoning block (extended_thinking content)."""
|
||||
|
||||
type: ResponseType = ResponseType.REASONING_START
|
||||
id: str = Field(..., description="Reasoning block ID")
|
||||
|
||||
|
||||
class StreamReasoningDelta(StreamBaseResponse):
|
||||
"""Streaming reasoning content delta."""
|
||||
|
||||
type: ResponseType = ResponseType.REASONING_DELTA
|
||||
id: str = Field(..., description="Reasoning block ID")
|
||||
delta: str = Field(..., description="Reasoning content delta")
|
||||
|
||||
|
||||
class StreamReasoningEnd(StreamBaseResponse):
|
||||
"""End of a reasoning block."""
|
||||
|
||||
type: ResponseType = ResponseType.REASONING_END
|
||||
id: str = Field(..., description="Reasoning block ID")
|
||||
|
||||
|
||||
# ========== Tool Interaction ==========
|
||||
|
||||
|
||||
|
||||
@@ -24,14 +24,10 @@ from typing import TYPE_CHECKING, Any
|
||||
# Static imports for type checkers so they can resolve __all__ entries
|
||||
# without executing the lazy-import machinery at runtime.
|
||||
if TYPE_CHECKING:
|
||||
from .collect import CopilotResult as CopilotResult
|
||||
from .collect import collect_copilot_response as collect_copilot_response
|
||||
from .service import stream_chat_completion_sdk as stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server as create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"CopilotResult",
|
||||
"collect_copilot_response",
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -39,8 +35,6 @@ __all__ = [
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"CopilotResult": (".collect", "CopilotResult"),
|
||||
"collect_copilot_response": (".collect", "collect_copilot_response"),
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
"""Public helpers for consuming a copilot stream as a simple request-response.
|
||||
|
||||
This module exposes :class:`CopilotResult` and :func:`collect_copilot_response`
|
||||
so that callers (e.g. the AutoPilot block) can consume the copilot stream
|
||||
without implementing their own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .. import stream_registry
|
||||
from ..response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .service import stream_chat_completion_sdk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Identifiers used when registering AutoPilot-originated streams in the
|
||||
# stream registry. Distinct from "chat_stream"/"chat" used by the HTTP SSE
|
||||
# endpoint, making it easy to filter AutoPilot streams in logs/observability.
|
||||
AUTOPILOT_TOOL_CALL_ID = "autopilot_stream"
|
||||
AUTOPILOT_TOOL_NAME = "autopilot"
|
||||
|
||||
|
||||
class CopilotResult:
|
||||
"""Aggregated result from consuming a copilot stream.
|
||||
|
||||
Returned by :func:`collect_copilot_response` so callers don't need to
|
||||
implement their own event-loop over the raw stream events.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"response_text",
|
||||
"tool_calls",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.response_text: str = ""
|
||||
self.tool_calls: list[dict[str, Any]] = []
|
||||
self.prompt_tokens: int = 0
|
||||
self.completion_tokens: int = 0
|
||||
self.total_tokens: int = 0
|
||||
|
||||
|
||||
class _RegistryHandle(BaseModel):
|
||||
"""Tracks stream registry session state for cleanup."""
|
||||
|
||||
publish_turn_id: str = ""
|
||||
error_msg: str | None = None
|
||||
error_already_published: bool = False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _registry_session(
|
||||
session_id: str, user_id: str, turn_id: str
|
||||
) -> AsyncIterator[_RegistryHandle]:
|
||||
"""Create a stream registry session and ensure it is finalized."""
|
||||
handle = _RegistryHandle(publish_turn_id=turn_id)
|
||||
try:
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=AUTOPILOT_TOOL_CALL_ID,
|
||||
tool_name=AUTOPILOT_TOOL_NAME,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"[collect] Failed to create stream registry session for %s, "
|
||||
"frontend will not receive real-time updates",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
# Disable chunk publishing but keep finalization enabled so
|
||||
# mark_session_completed can clean up any partial registry state.
|
||||
handle.publish_turn_id = ""
|
||||
|
||||
try:
|
||||
yield handle
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.mark_session_completed(
|
||||
session_id,
|
||||
error_message=handle.error_msg,
|
||||
skip_error_publish=handle.error_already_published,
|
||||
)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"[collect] Failed to mark stream completed for %s",
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class _ToolCallEntry(BaseModel):
|
||||
"""A single tool call observed during stream consumption."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any = None
|
||||
success: bool | None = None
|
||||
|
||||
|
||||
class _EventAccumulator(BaseModel):
|
||||
"""Mutable accumulator for stream events."""
|
||||
|
||||
response_parts: list[str] = Field(default_factory=list)
|
||||
tool_calls: list[_ToolCallEntry] = Field(default_factory=list)
|
||||
tool_calls_by_id: dict[str, _ToolCallEntry] = Field(default_factory=dict)
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
def _process_event(event: object, acc: _EventAccumulator) -> str | None:
|
||||
"""Process a single stream event and return error_msg if StreamError.
|
||||
|
||||
Uses structural pattern matching for dispatch per project guidelines.
|
||||
"""
|
||||
match event:
|
||||
case StreamTextDelta(delta=delta):
|
||||
acc.response_parts.append(delta)
|
||||
case StreamToolInputAvailable() as e:
|
||||
entry = _ToolCallEntry(
|
||||
tool_call_id=e.toolCallId,
|
||||
tool_name=e.toolName,
|
||||
input=e.input,
|
||||
)
|
||||
acc.tool_calls.append(entry)
|
||||
acc.tool_calls_by_id[e.toolCallId] = entry
|
||||
case StreamToolOutputAvailable() as e:
|
||||
if tc := acc.tool_calls_by_id.get(e.toolCallId):
|
||||
tc.output = e.output
|
||||
tc.success = e.success
|
||||
else:
|
||||
logger.debug(
|
||||
"Received tool output for unknown tool_call_id: %s",
|
||||
e.toolCallId,
|
||||
)
|
||||
case StreamUsage() as e:
|
||||
acc.prompt_tokens += e.prompt_tokens
|
||||
acc.completion_tokens += e.completion_tokens
|
||||
acc.total_tokens += e.total_tokens
|
||||
case StreamError(errorText=err):
|
||||
return err
|
||||
return None
|
||||
|
||||
|
||||
async def collect_copilot_response(
|
||||
*,
|
||||
session_id: str,
|
||||
message: str,
|
||||
user_id: str,
|
||||
is_user_message: bool = True,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> CopilotResult:
|
||||
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
|
||||
|
||||
Registers with the stream registry so the frontend can connect via SSE
|
||||
and receive real-time updates while the AutoPilot block is executing.
|
||||
|
||||
Args:
|
||||
session_id: Chat session to use.
|
||||
message: The user message / prompt.
|
||||
user_id: Authenticated user ID.
|
||||
is_user_message: Whether this is a user-initiated message.
|
||||
permissions: Optional capability filter. When provided, restricts
|
||||
which tools and blocks the copilot may use during this execution.
|
||||
|
||||
Returns:
|
||||
A :class:`CopilotResult` with the aggregated response text,
|
||||
tool calls, and token usage.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stream yields a ``StreamError`` event.
|
||||
"""
|
||||
turn_id = str(uuid.uuid4())
|
||||
async with _registry_session(session_id, user_id, turn_id) as handle:
|
||||
try:
|
||||
raw_stream = stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
published_stream = stream_registry.stream_and_publish(
|
||||
session_id=session_id,
|
||||
turn_id=handle.publish_turn_id,
|
||||
stream=raw_stream,
|
||||
)
|
||||
|
||||
acc = _EventAccumulator()
|
||||
async for event in published_stream:
|
||||
if err := _process_event(event, acc):
|
||||
handle.error_msg = err
|
||||
# stream_and_publish skips StreamError events, so
|
||||
# mark_session_completed must publish the error to Redis.
|
||||
handle.error_already_published = False
|
||||
raise RuntimeError(f"Copilot error: {err}")
|
||||
except Exception:
|
||||
if handle.error_msg is None:
|
||||
handle.error_msg = "AutoPilot execution failed"
|
||||
raise
|
||||
|
||||
result = CopilotResult()
|
||||
result.response_text = "".join(acc.response_parts)
|
||||
result.tool_calls = [tc.model_dump() for tc in acc.tool_calls]
|
||||
result.prompt_tokens = acc.prompt_tokens
|
||||
result.completion_tokens = acc.completion_tokens
|
||||
result.total_tokens = acc.total_tokens
|
||||
return result
|
||||
@@ -1,177 +0,0 @@
|
||||
"""Tests for collect_copilot_response stream registry integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
|
||||
|
||||
def _mock_stream_fn(*events):
|
||||
"""Return a callable that returns an async generator."""
|
||||
|
||||
async def _gen(**_kwargs):
|
||||
for e in events:
|
||||
yield e
|
||||
|
||||
return _gen
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry():
|
||||
"""Patch stream_registry module used by collect."""
|
||||
with patch("backend.copilot.sdk.collect.stream_registry") as m:
|
||||
m.create_session = AsyncMock()
|
||||
m.publish_chunk = AsyncMock()
|
||||
m.mark_session_completed = AsyncMock()
|
||||
|
||||
# stream_and_publish: pass-through that also publishes (real logic)
|
||||
# We re-implement the pass-through here so the event loop works,
|
||||
# but still track publish_chunk calls via the mock.
|
||||
async def _stream_and_publish(session_id, turn_id, stream):
|
||||
async for event in stream:
|
||||
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
|
||||
await m.publish_chunk(turn_id, event)
|
||||
yield event
|
||||
|
||||
m.stream_and_publish = _stream_and_publish
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream_fn_patch():
|
||||
"""Helper to patch stream_chat_completion_sdk."""
|
||||
|
||||
def _patch(events):
|
||||
return patch(
|
||||
"backend.copilot.sdk.collect.stream_chat_completion_sdk",
|
||||
new=_mock_stream_fn(*events),
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_registry_called_on_success(mock_registry, stream_fn_patch):
|
||||
"""Stream registry create/publish/complete are called correctly on success."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="Hello "),
|
||||
StreamTextDelta(id="t1", delta="world"),
|
||||
StreamUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
StreamFinish(),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result.response_text == "Hello world"
|
||||
assert result.total_tokens == 15
|
||||
|
||||
mock_registry.create_session.assert_awaited_once()
|
||||
# StreamFinish should NOT be published (mark_session_completed does it)
|
||||
published_types = [
|
||||
type(call.args[1]).__name__
|
||||
for call in mock_registry.publish_chunk.call_args_list
|
||||
]
|
||||
assert "StreamFinish" not in published_types
|
||||
assert "StreamTextDelta" in published_types
|
||||
|
||||
mock_registry.mark_session_completed.assert_awaited_once()
|
||||
_, kwargs = mock_registry.mark_session_completed.call_args
|
||||
assert kwargs.get("error_message") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_registry_error_on_stream_error(mock_registry, stream_fn_patch):
|
||||
"""mark_session_completed receives error message when StreamError occurs."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="partial"),
|
||||
StreamError(errorText="something broke"),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
with pytest.raises(RuntimeError, match="something broke"):
|
||||
await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
_, kwargs = mock_registry.mark_session_completed.call_args
|
||||
assert kwargs.get("error_message") == "something broke"
|
||||
# stream_and_publish skips StreamError, so mark_session_completed must
|
||||
# publish it (skip_error_publish=False).
|
||||
assert kwargs.get("skip_error_publish") is False
|
||||
|
||||
# StreamError should NOT be published via publish_chunk — mark_session_completed
|
||||
# handles it to avoid double-publication.
|
||||
published_types = [
|
||||
type(call.args[1]).__name__
|
||||
for call in mock_registry.publish_chunk.call_args_list
|
||||
]
|
||||
assert "StreamError" not in published_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_degradation_when_create_session_fails(
|
||||
mock_registry, stream_fn_patch
|
||||
):
|
||||
"""AutoPilot still works when stream registry create_session raises."""
|
||||
events = [
|
||||
StreamTextDelta(id="t1", delta="works"),
|
||||
StreamFinish(),
|
||||
]
|
||||
mock_registry.create_session = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result.response_text == "works"
|
||||
# publish_chunk should NOT be called because turn_id was cleared
|
||||
mock_registry.publish_chunk.assert_not_awaited()
|
||||
# mark_session_completed IS still called to clean up any partial state
|
||||
mock_registry.mark_session_completed.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_calls_published_and_collected(mock_registry, stream_fn_patch):
|
||||
"""Tool call events are both published to registry and collected in result."""
|
||||
events = [
|
||||
StreamToolInputAvailable(
|
||||
toolCallId="tc-1", toolName="read_file", input={"path": "/tmp"}
|
||||
),
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId="tc-1", output="file contents", success=True
|
||||
),
|
||||
StreamTextDelta(id="t1", delta="done"),
|
||||
StreamFinish(),
|
||||
]
|
||||
|
||||
with stream_fn_patch(events):
|
||||
result = await collect_copilot_response(
|
||||
session_id="test-session",
|
||||
message="hi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0]["tool_name"] == "read_file"
|
||||
assert result.tool_calls[0]["output"] == "file contents"
|
||||
assert result.tool_calls[0]["success"] is True
|
||||
assert result.response_text == "done"
|
||||
@@ -84,9 +84,10 @@ async def test_resolve_file_ref_local_path_with_line_range():
|
||||
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
|
||||
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var:
|
||||
with (
|
||||
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var,
|
||||
patch("backend.copilot.context._current_sandbox") as mock_sandbox_var,
|
||||
):
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_sandbox_var.get.return_value = None
|
||||
|
||||
@@ -387,11 +388,13 @@ async def test_read_file_handler_local_file():
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj_var, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
with (
|
||||
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var,
|
||||
patch("backend.copilot.context._current_project_dir") as mock_proj_var,
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
),
|
||||
):
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
# No project_dir set — so is_sdk_tool_path returns False for sdk_cwd paths
|
||||
@@ -413,12 +416,15 @@ async def test_read_file_handler_workspace_uri():
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
|
||||
@@ -446,11 +452,13 @@ async def test_read_file_handler_workspace_uri_no_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_access_denied():
|
||||
"""_read_file_handler rejects paths outside allowed locations."""
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
with (
|
||||
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd,
|
||||
patch("backend.copilot.context._current_sandbox") as mock_sandbox,
|
||||
patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
),
|
||||
):
|
||||
mock_cwd.get.return_value = "/tmp/safe-dir"
|
||||
mock_sandbox.get.return_value = None
|
||||
@@ -490,11 +498,11 @@ async def test_read_file_bytes_e2b_sandbox_branch():
|
||||
mock_sandbox = AsyncMock()
|
||||
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
with (
|
||||
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd,
|
||||
patch("backend.copilot.context._current_sandbox") as mock_sandbox_var,
|
||||
patch("backend.copilot.context._current_project_dir") as mock_proj,
|
||||
):
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
@@ -513,11 +521,11 @@ async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
with (
|
||||
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd,
|
||||
patch("backend.copilot.context._current_sandbox") as mock_sandbox_var,
|
||||
patch("backend.copilot.context._current_project_dir") as mock_proj,
|
||||
):
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
@@ -1394,11 +1394,7 @@ async def test_e2e_toml_dict_with_list_value_to_concat_block():
|
||||
"""TOML dict with a list value → List[List[Any]] block: extracts list
|
||||
values, ignoring scalar values like 'title'."""
|
||||
toml_content = (
|
||||
'title = "Fruits"\n'
|
||||
"[[fruits]]\n"
|
||||
'name = "apple"\n'
|
||||
"[[fruits]]\n"
|
||||
'name = "banana"\n'
|
||||
'title = "Fruits"\n[[fruits]]\nname = "apple"\n[[fruits]]\nname = "banana"\n'
|
||||
)
|
||||
|
||||
async def _resolve(ref, *a, **kw): # noqa: ARG001
|
||||
@@ -1692,12 +1688,15 @@ async def test_media_file_field_passthrough_workspace_uri():
|
||||
},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=AssertionError("should not read file content")),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.read_file_bytes",
|
||||
new=AsyncMock(side_effect=AssertionError("should not read file bytes")),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=AssertionError("should not read file content")),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.file_ref.read_file_bytes",
|
||||
new=AsyncMock(side_effect=AssertionError("should not read file bytes")),
|
||||
),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{"image": "@@agptfile:workspace://img123"},
|
||||
|
||||
@@ -8,7 +8,7 @@ Cross-mode transcript flow
|
||||
==========================
|
||||
|
||||
Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
|
||||
mode) read and write the same JSONL transcript store via
|
||||
mode) read and write the same CLI session store via
|
||||
``backend.copilot.transcript.upload_transcript`` /
|
||||
``download_transcript``.
|
||||
|
||||
@@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_s_baseline_loads_sdk_transcript(self):
|
||||
"""Scenario S: SDK-written transcript is accepted by baseline's load helper."""
|
||||
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
|
||||
from backend.copilot.baseline.service import _load_prior_transcript
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
@@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch:
|
||||
sdk_transcript = builder_sdk.to_jsonl()
|
||||
|
||||
# Baseline session now has those 2 SDK messages + 1 new baseline message.
|
||||
download = TranscriptDownload(content=sdk_transcript, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
baseline_builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3, # 2 SDK + 1 new baseline
|
||||
session_messages=[
|
||||
ChatMessage(role="user", content="sdk-question"),
|
||||
ChatMessage(role="assistant", content="sdk-answer"),
|
||||
ChatMessage(role="user", content="baseline-question"),
|
||||
],
|
||||
transcript_builder=baseline_builder,
|
||||
)
|
||||
|
||||
# Transcript is valid and covers the prefix.
|
||||
# CLI session is valid and covers the prefix.
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
assert baseline_builder.entry_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
|
||||
"""Scenario S (stale): SDK transcript is stale — baseline does not load it.
|
||||
"""Scenario S (stale): SDK CLI session is stale — baseline does not load it.
|
||||
|
||||
If SDK mode produced more turns than the transcript captured (e.g.
|
||||
upload failed on one turn), the baseline rejects the stale transcript
|
||||
If SDK mode produced more turns than the session captured (e.g.
|
||||
upload failed on one turn), the baseline rejects the stale session
|
||||
to avoid injecting an incomplete history.
|
||||
"""
|
||||
from backend.copilot.baseline.service import _load_prior_transcript
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
@@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch:
|
||||
)
|
||||
sdk_transcript = builder_sdk.to_jsonl()
|
||||
|
||||
# Transcript covers only 2 messages but session has 10 (many SDK turns).
|
||||
download = TranscriptDownload(content=sdk_transcript, message_count=2)
|
||||
# Session covers only 2 messages but session has 10 (many SDK turns).
|
||||
# With watermark=2 and 10 total messages, detect_gap will fill the gap
|
||||
# by appending messages 2..8 (positions 2 to total-2).
|
||||
restore = TranscriptDownload(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
# Build a session with 10 alternating user/assistant messages + current user
|
||||
session_messages = [
|
||||
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
baseline_builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
session_messages=session_messages,
|
||||
transcript_builder=baseline_builder,
|
||||
)
|
||||
|
||||
# Stale transcript must be rejected.
|
||||
assert covers is False
|
||||
assert baseline_builder.is_empty
|
||||
# With gap filling, covers is True and gap messages are appended.
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
# 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn)
|
||||
assert baseline_builder.entry_count == 9
|
||||
|
||||
@@ -255,6 +255,111 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
assert was_compacted is False # mock returns False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_session_msg_ceiling_prevents_pending_duplication():
|
||||
"""session_msg_ceiling stops pending messages from leaking into the gap.
|
||||
|
||||
Scenario: transcript covers 2 messages, session has 2 historical + 1 current
|
||||
+ 2 pending drained at turn start. Without the ceiling the gap would include
|
||||
the pending messages AND current_message already has them → duplication.
|
||||
With session_msg_ceiling=3 (pre-drain count) the gap slice is empty and
|
||||
only current_message carries the pending content.
|
||||
"""
|
||||
# session.messages after drain: [hist1, hist2, current_msg, pending1, pending2]
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hist1"),
|
||||
ChatMessage(role="assistant", content="hist2"),
|
||||
ChatMessage(role="user", content="current msg with pending1 pending2"),
|
||||
ChatMessage(role="user", content="pending1"),
|
||||
ChatMessage(role="user", content="pending2"),
|
||||
]
|
||||
)
|
||||
# transcript covers hist1+hist2 (2 messages); pre-drain count was 3 (includes current_msg)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"current msg with pending1 pending2",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
session_msg_ceiling=3, # len(session.messages) before drain
|
||||
)
|
||||
# Gap should be empty (transcript_msg_count == ceiling - 1), so no history prepended
|
||||
assert result == "current msg with pending1 pending2"
|
||||
assert was_compacted is False
|
||||
# Pending messages must NOT appear in gap context
|
||||
assert "pending1" not in result.split("current msg")[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_session_msg_ceiling_preserves_real_gap():
|
||||
"""session_msg_ceiling still surfaces a genuine stale-transcript gap.
|
||||
|
||||
Scenario: transcript covers 2 messages, session has 4 historical + 1 current
|
||||
+ 2 pending. Ceiling = 5 (pre-drain). Real gap = messages 2-3 (hist3, hist4).
|
||||
"""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hist1"),
|
||||
ChatMessage(role="assistant", content="hist2"),
|
||||
ChatMessage(role="user", content="hist3"),
|
||||
ChatMessage(role="assistant", content="hist4"),
|
||||
ChatMessage(role="user", content="current"),
|
||||
ChatMessage(role="user", content="pending1"),
|
||||
ChatMessage(role="user", content="pending2"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"current",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
session_msg_ceiling=5, # pre-drain: [hist1..hist4, current]
|
||||
)
|
||||
# Gap = session.messages[2:4] = [hist3, hist4]
|
||||
assert "<conversation_history>" in result
|
||||
assert "hist3" in result
|
||||
assert "hist4" in result
|
||||
assert "Now, the user says:\ncurrent" in result
|
||||
# Pending messages must NOT appear in gap
|
||||
assert "pending1" not in result
|
||||
assert "pending2" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_session_msg_ceiling_suppresses_spurious_no_resume_fallback():
|
||||
"""session_msg_ceiling prevents the no-resume compression fallback from
|
||||
firing on the first turn of a session when pending messages inflate msg_count.
|
||||
|
||||
Scenario: fresh session (1 message) + 1 pending message drained at turn start.
|
||||
Without the ceiling: msg_count=2 > 1 → fallback triggers → pending message
|
||||
leaked into history → wrong context sent to model.
|
||||
With session_msg_ceiling=1 (pre-drain count): effective_count=1, 1 > 1 is False
|
||||
→ fallback does not trigger → current_message returned as-is.
|
||||
"""
|
||||
# session.messages after drain: [current_msg, pending_msg]
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="What is 2 plus 2?"),
|
||||
ChatMessage(role="user", content="What is 7 plus 7?"), # pending
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"What is 2 plus 2?\n\nWhat is 7 plus 7?",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
session_msg_ceiling=1, # pre-drain: only 1 message existed
|
||||
)
|
||||
# Should return current_message directly without wrapping in history context
|
||||
assert result == "What is 2 plus 2?\n\nWhat is 7 plus 7?"
|
||||
assert was_compacted is False
|
||||
# Pending question must NOT appear in a spurious history section
|
||||
assert "<conversation_history>" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
|
||||
"""When compression actually compacts, was_compacted should be True."""
|
||||
|
||||
@@ -28,6 +28,9 @@ from backend.copilot.response_model import (
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamReasoningDelta,
|
||||
StreamReasoningEnd,
|
||||
StreamReasoningStart,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
@@ -56,9 +59,21 @@ class SDKResponseAdapter:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.reasoning_block_id = str(uuid.uuid4())
|
||||
self.has_started_reasoning = False
|
||||
self.has_ended_reasoning = True
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.resolved_tool_calls: set[str] = set()
|
||||
self.step_open = False
|
||||
# Track whether any ``TextBlock`` was emitted after the most recent
|
||||
# tool_result. Used at ``ResultMessage`` time to detect the
|
||||
# "thinking-only final turn" case — when Claude's last LLM call
|
||||
# produced only a ``ThinkingBlock`` (no text, no tool_use) the UI
|
||||
# hangs on the last tool result with a "Thought for Xs" label and
|
||||
# no response text. We synthesize a short closing line in that
|
||||
# case so the turn renders as cleanly complete.
|
||||
self._text_since_last_tool_result = False
|
||||
self._any_tool_results_seen = False
|
||||
|
||||
@property
|
||||
def has_unresolved_tool_calls(self) -> bool:
|
||||
@@ -103,18 +118,43 @@ class SDKResponseAdapter:
|
||||
for block in sdk_message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
if block.text:
|
||||
# Reasoning and text are distinct UI parts; close
|
||||
# any open reasoning block before opening text so
|
||||
# the AI SDK transport doesn't merge them.
|
||||
self._end_reasoning_if_open(responses)
|
||||
self._ensure_text_started(responses)
|
||||
responses.append(
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
self._text_since_last_tool_result = True
|
||||
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
# Thinking blocks are preserved in the transcript but
|
||||
# not streamed to the frontend — skip silently.
|
||||
pass
|
||||
# Stream extended_thinking content as a reasoning
|
||||
# block. The Vercel AI SDK's ``useChat`` transport
|
||||
# recognises ``reasoning-start`` / ``reasoning-delta``
|
||||
# / ``reasoning-end`` events and accumulates them into
|
||||
# a ``type: 'reasoning'`` UIMessage part the frontend
|
||||
# renders via ``ReasoningCollapse`` (collapsed by
|
||||
# default). We also persist the text as a
|
||||
# ``type: 'thinking'`` part in ``session.messages`` via
|
||||
# ``_format_sdk_content_blocks``, so shared / reloaded
|
||||
# sessions see the same reasoning. Without streaming
|
||||
# it live, extended_thinking turns that end
|
||||
# thinking-only left the UI stuck on "Thought for Xs"
|
||||
# with nothing rendered until a page refresh.
|
||||
if block.thinking:
|
||||
self._end_text_if_open(responses)
|
||||
self._ensure_reasoning_started(responses)
|
||||
responses.append(
|
||||
StreamReasoningDelta(
|
||||
id=self.reasoning_block_id,
|
||||
delta=block.thinking,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
self._end_reasoning_if_open(responses)
|
||||
|
||||
# Strip MCP prefix so frontend sees "find_block"
|
||||
# instead of "mcp__copilot__find_block".
|
||||
@@ -210,16 +250,58 @@ class SDKResponseAdapter:
|
||||
resolved_in_blocks.add(parent_id)
|
||||
|
||||
self.resolved_tool_calls.update(resolved_in_blocks)
|
||||
if resolved_in_blocks:
|
||||
# A new tool_result just landed — reset the
|
||||
# "has the model emitted text since the last tool result?"
|
||||
# tracker so the thinking-only-final-turn guard at
|
||||
# ``ResultMessage`` time stays accurate.
|
||||
self._text_since_last_tool_result = False
|
||||
self._any_tool_results_seen = True
|
||||
|
||||
# Close the current step after tool results — the next
|
||||
# AssistantMessage will open a new step for the continuation.
|
||||
if self.step_open:
|
||||
self._end_reasoning_if_open(responses)
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
# Thinking-only final turn guard: when the model's last LLM
|
||||
# call after a tool result produced only a ``ThinkingBlock``
|
||||
# (no ``TextBlock``, no ``ToolUseBlock``) the UI has nothing
|
||||
# to render after the tool output — it hangs on "Thought for
|
||||
# Xs" with no response text. Synthesise a short closing line
|
||||
# so the turn visibly completes. Condition: we've seen at
|
||||
# least one tool_result AND zero TextBlocks since. The
|
||||
# prompt rule (``_USER_FOLLOW_UP_NOTE``'s closing clause)
|
||||
# asks the model to always end with text, but we can't rely
|
||||
# on it for extended_thinking / edge cases.
|
||||
if (
|
||||
self._any_tool_results_seen
|
||||
and not self._text_since_last_tool_result
|
||||
and sdk_message.subtype == "success"
|
||||
):
|
||||
# UserMessage (tool_result) closed the last step, so we must
|
||||
# open a fresh one before emitting any text — the AI SDK v5
|
||||
# transport rejects text-delta chunks that aren't wrapped in
|
||||
# start-step / finish-step.
|
||||
if not self.step_open:
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
# Close any open reasoning block first — text and reasoning
|
||||
# must not interleave on the wire (AI SDK v5 maps distinct
|
||||
# start/end events to distinct UI parts).
|
||||
self._end_reasoning_if_open(responses)
|
||||
self._ensure_text_started(responses)
|
||||
responses.append(
|
||||
StreamTextDelta(
|
||||
id=self.text_block_id,
|
||||
delta="(Done — no further commentary.)",
|
||||
)
|
||||
)
|
||||
self._end_text_if_open(responses)
|
||||
self._end_reasoning_if_open(responses)
|
||||
# Close the step before finishing.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
@@ -261,6 +343,26 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
def _ensure_reasoning_started(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Start (or restart) a reasoning block if needed.
|
||||
|
||||
Each ``ThinkingBlock`` the SDK emits gets its own streaming block
|
||||
on the wire so the frontend can render a new ``Reasoning`` part
|
||||
per LLM turn (rather than concatenating across the whole session).
|
||||
"""
|
||||
if not self.has_started_reasoning or self.has_ended_reasoning:
|
||||
if self.has_ended_reasoning:
|
||||
self.reasoning_block_id = str(uuid.uuid4())
|
||||
self.has_ended_reasoning = False
|
||||
responses.append(StreamReasoningStart(id=self.reasoning_block_id))
|
||||
self.has_started_reasoning = True
|
||||
|
||||
def _end_reasoning_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""End the current reasoning block if one is open."""
|
||||
if self.has_started_reasoning and not self.has_ended_reasoning:
|
||||
responses.append(StreamReasoningEnd(id=self.reasoning_block_id))
|
||||
self.has_ended_reasoning = True
|
||||
|
||||
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Emit outputs for tool calls that didn't receive a UserMessage result.
|
||||
|
||||
@@ -305,7 +407,7 @@ class SDKResponseAdapter:
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
|
||||
"[SDK] [%s] Flushed stashed output for %s (call %s, %d chars)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
@@ -335,9 +437,17 @@ class SDKResponseAdapter:
|
||||
tool_id[:12],
|
||||
)
|
||||
|
||||
if flushed and self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
if flushed:
|
||||
# Mirror the UserMessage tool_result path: a flushed tool output is
|
||||
# still a tool_result as far as the thinking-only-final-turn guard
|
||||
# is concerned. Without this, a turn whose ONLY tool outputs come
|
||||
# from the flush path (SDK built-ins like WebSearch) would miss
|
||||
# the fallback synthesis if the model then produced no text.
|
||||
self._text_since_last_tool_result = False
|
||||
self._any_tool_results_seen = True
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
|
||||
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
||||
|
||||
@@ -8,6 +8,7 @@ from claude_agent_sdk import (
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
@@ -19,6 +20,7 @@ from backend.copilot.response_model import (
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamReasoningDelta,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
@@ -251,6 +253,200 @@ def test_result_success_emits_finish_step_and_finish():
|
||||
assert isinstance(results[2], StreamFinish)
|
||||
|
||||
|
||||
# -- Reasoning streaming -----------------------------------------------------
|
||||
|
||||
|
||||
def test_thinking_block_streams_as_reasoning():
|
||||
"""ThinkingBlock content streams as StreamReasoningDelta so the
|
||||
frontend renders it via the ``Reasoning`` part (collapsed by
|
||||
default) instead of dropping it silently."""
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ThinkingBlock(thinking="planning step 1", signature="sig"),
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# Step + ReasoningStart + ReasoningDelta
|
||||
types = [type(r).__name__ for r in results]
|
||||
assert "StreamReasoningStart" in types
|
||||
assert any(
|
||||
isinstance(r, StreamReasoningDelta) and r.delta == "planning step 1"
|
||||
for r in results
|
||||
)
|
||||
|
||||
|
||||
def test_text_after_thinking_closes_reasoning_and_opens_text():
|
||||
"""Reasoning and text are distinct UI parts — opening text must
|
||||
emit ``ReasoningEnd`` first so the AI SDK transport doesn't merge
|
||||
them into the same ``Reasoning`` part."""
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ThinkingBlock(thinking="warming up", signature="sig")],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
results = adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||
)
|
||||
types = [type(r).__name__ for r in results]
|
||||
# ReasoningEnd must come before TextStart
|
||||
re_idx = types.index("StreamReasoningEnd")
|
||||
ts_idx = types.index("StreamTextStart")
|
||||
assert re_idx < ts_idx
|
||||
|
||||
|
||||
def test_tool_use_after_thinking_closes_reasoning():
|
||||
"""Opening a tool also closes an open reasoning block."""
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ThinkingBlock(thinking="let me search", signature="sig")],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
results = adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_block", input={})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
types = [type(r).__name__ for r in results]
|
||||
assert types.index("StreamReasoningEnd") < types.index("StreamToolInputStart")
|
||||
|
||||
|
||||
def test_empty_thinking_block_is_ignored():
|
||||
"""A ThinkingBlock with empty content shouldn't emit anything."""
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(
|
||||
content=[ThinkingBlock(thinking="", signature="sig")],
|
||||
model="test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# Only the StepStart fires — no reasoning events.
|
||||
assert [type(r).__name__ for r in results] == ["StreamStartStep"]
|
||||
|
||||
|
||||
def test_result_success_synthesizes_fallback_text_when_final_turn_is_thinking_only():
|
||||
"""If the model's last LLM call after a tool_result produced only a
|
||||
ThinkingBlock (no TextBlock), the UI would hang on the tool output
|
||||
with no response text. The adapter should inject a short closing
|
||||
line before ``StreamFinish`` so the turn visibly completes."""
|
||||
adapter = _adapter()
|
||||
|
||||
# Tool use + tool_result (simulates the tool round).
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_block", input={}),
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
adapter.convert_message(
|
||||
UserMessage(
|
||||
content=[
|
||||
ToolResultBlock(tool_use_id="t1", content="result", is_error=False)
|
||||
],
|
||||
parent_tool_use_id=None,
|
||||
)
|
||||
)
|
||||
|
||||
# Model's "final turn" after tool_result is thinking-only. This test
|
||||
# simulates the *degenerate* case where the SDK never surfaces an
|
||||
# AssistantMessage carrying the ThinkingBlock at all (not even the
|
||||
# streamed reasoning events) before ResultMessage — only the tool_result
|
||||
# has arrived. The fallback guard should still synthesize closing text.
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=4,
|
||||
session_id="s1",
|
||||
result="",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
|
||||
# Fallback text should be injected before the finish events.
|
||||
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
|
||||
assert len(text_deltas) == 1, "should synthesize exactly one fallback text"
|
||||
assert text_deltas[0].delta.strip() # non-empty
|
||||
assert isinstance(results[-1], StreamFinish)
|
||||
|
||||
|
||||
def test_result_success_does_not_synthesize_when_text_already_emitted():
|
||||
"""Guard: do NOT synthesize when the model DID emit closing text
|
||||
after the last tool result — the fallback is only for the silent
|
||||
thinking-only case."""
|
||||
adapter = _adapter()
|
||||
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_block", input={})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
adapter.convert_message(
|
||||
UserMessage(
|
||||
content=[
|
||||
ToolResultBlock(tool_use_id="t1", content="result", is_error=False)
|
||||
],
|
||||
parent_tool_use_id=None,
|
||||
)
|
||||
)
|
||||
# Model responds with actual text after the tool result.
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="all done")], model="test")
|
||||
)
|
||||
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=4,
|
||||
session_id="s1",
|
||||
result="all done",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
|
||||
# No fallback — the only TextDelta came from the previous
|
||||
# AssistantMessage call, not from ResultMessage's synthesis.
|
||||
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
|
||||
assert text_deltas == []
|
||||
|
||||
|
||||
def test_result_success_does_not_synthesize_when_no_tools_ran():
|
||||
"""Guard: no tool_results seen ⇒ no fallback. Pure-text turns with
|
||||
no tools legitimately produce text-only responses through normal
|
||||
AssistantMessage events; we don't need a fallback there."""
|
||||
adapter = _adapter()
|
||||
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||
)
|
||||
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
result="hello",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
|
||||
assert text_deltas == []
|
||||
|
||||
|
||||
def test_result_error_emits_error_and_finish():
|
||||
adapter = _adapter()
|
||||
msg = ResultMessage(
|
||||
@@ -426,6 +622,13 @@ def test_flush_unresolved_at_result_message():
|
||||
"StreamToolInputAvailable",
|
||||
"StreamToolOutputAvailable", # flushed with empty output
|
||||
"StreamFinishStep", # step closed by flush
|
||||
# Flush marks a tool_result as seen, so the thinking-only-final-turn
|
||||
# guard at ResultMessage time synthesizes a closing text delta.
|
||||
"StreamStartStep",
|
||||
"StreamTextStart",
|
||||
"StreamTextDelta",
|
||||
"StreamTextEnd",
|
||||
"StreamFinishStep",
|
||||
"StreamFinish",
|
||||
]
|
||||
# The flushed output should be empty (no stash available)
|
||||
|
||||
@@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.copilot.transcript import (
|
||||
TranscriptDownload,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
@@ -999,14 +1000,15 @@ def _make_sdk_patches(
|
||||
f"{_SVC}.download_transcript",
|
||||
dict(
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(content=original_transcript, message_count=2),
|
||||
return_value=TranscriptDownload(
|
||||
content=original_transcript.encode("utf-8"),
|
||||
message_count=2,
|
||||
mode="sdk",
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
dict(new_callable=AsyncMock, return_value=True),
|
||||
),
|
||||
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
|
||||
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.validate_transcript", dict(return_value=True)),
|
||||
(
|
||||
f"{_SVC}.compact_transcript",
|
||||
@@ -1037,8 +1039,13 @@ def _make_sdk_patches(
|
||||
claude_agent_fallback_model=None,
|
||||
),
|
||||
),
|
||||
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
|
||||
# Stub pending-message drain so retry tests don't hit Redis.
|
||||
# Returns an empty list → no mid-turn injection happens.
|
||||
(
|
||||
f"{_SVC}.drain_pending_safe",
|
||||
dict(new_callable=AsyncMock, return_value=[]),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -1914,14 +1921,14 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
compacted_transcript=None,
|
||||
client_side_effect=_client_factory,
|
||||
)
|
||||
# Override restore_cli_session to return False (CLI native session unavailable)
|
||||
# Override download_transcript to return None (CLI native session unavailable)
|
||||
patches = [
|
||||
(
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
dict(new_callable=AsyncMock, return_value=False),
|
||||
f"{_SVC}.download_transcript",
|
||||
dict(new_callable=AsyncMock, return_value=None),
|
||||
)
|
||||
if p[0] == f"{_SVC}.restore_cli_session"
|
||||
if p[0] == f"{_SVC}.download_transcript"
|
||||
else p
|
||||
)
|
||||
for p in patches
|
||||
@@ -1944,7 +1951,7 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
# captured_options holds {"options": ClaudeAgentOptions}, so check
|
||||
# the attribute directly rather than dict keys.
|
||||
assert not getattr(captured_options.get("options"), "resume", None), (
|
||||
f"--resume was set even though restore_cli_session returned False: "
|
||||
f"--resume was set even though download_transcript returned None: "
|
||||
f"{captured_options}"
|
||||
)
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@@ -94,21 +94,23 @@ def test_agent_options_accepts_required_fields():
|
||||
def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections():
|
||||
"""Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces.
|
||||
|
||||
The production code always includes ``exclude_dynamic_sections=True`` in the preset
|
||||
dict. This compat test mirrors that exact shape so any SDK version that starts
|
||||
rejecting unknown keys will be caught here rather than at runtime.
|
||||
The Turn 1 (non-resume) code path includes ``exclude_dynamic_sections=True`` in
|
||||
the preset dict for cross-user caching. This compat test mirrors that exact
|
||||
shape so any SDK version that starts rejecting unknown keys will be caught
|
||||
here rather than at runtime.
|
||||
"""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
from claude_agent_sdk.types import SystemPromptPreset
|
||||
|
||||
from .service import _build_system_prompt_value
|
||||
|
||||
# Call the production helper directly so this test is tied to the real
|
||||
# dict shape rather than a hand-rolled copy.
|
||||
preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True)
|
||||
assert isinstance(
|
||||
preset, dict
|
||||
), "_build_system_prompt_value must return a dict when caching is on"
|
||||
assert preset.get("exclude_dynamic_sections") is True, (
|
||||
"Turn 1 must strip dynamic sections to keep the prefix cacheable " "cross-user"
|
||||
)
|
||||
|
||||
sdk_preset = cast(SystemPromptPreset, preset)
|
||||
opts = ClaudeAgentOptions(system_prompt=sdk_preset)
|
||||
@@ -116,8 +118,9 @@ def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_section
|
||||
|
||||
|
||||
def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off():
|
||||
"""When cross_user_cache=False (e.g. on --resume turns), the helper must return
|
||||
a plain string so the preset+resume crash is avoided."""
|
||||
"""When cross_user_cache=False (feature flag disabled globally), the
|
||||
helper returns a plain string; the CLI will receive --system-prompt
|
||||
(replace-mode) and skip the preset entirely."""
|
||||
from .service import _build_system_prompt_value
|
||||
|
||||
result = _build_system_prompt_value("my prompt", cross_user_cache=False)
|
||||
@@ -262,6 +265,12 @@ _KNOWN_GOOD_BUNDLED_CLI_VERSIONS: frozenset[str] = frozenset(
|
||||
"2.1.97", # claude-agent-sdk 0.1.58 -- OpenRouter-safe only with
|
||||
# CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 (injected by
|
||||
# build_sdk_env() in env.py).
|
||||
"2.1.116", # claude-agent-sdk 0.1.64 -- first bundled version that
|
||||
# fixes the --resume + excludeDynamicSections=True crash
|
||||
# (introduced in 2.1.98), unlocking cross-user prompt
|
||||
# cache reads on every resumed SDK turn. Still requires
|
||||
# CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1. Verified
|
||||
# OpenRouter-safe via cli_openrouter_compat_test.py.
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,12 @@ import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.copilot.context import is_allowed_local_path, is_sdk_tool_path
|
||||
from backend.copilot.context import (
|
||||
get_execution_context,
|
||||
is_allowed_local_path,
|
||||
is_sdk_tool_path,
|
||||
)
|
||||
from backend.copilot.pending_messages import drain_and_format_for_injection
|
||||
|
||||
from .tool_adapter import (
|
||||
BLOCKED_TOOLS,
|
||||
@@ -327,6 +332,30 @@ def create_security_hooks(
|
||||
tool_name,
|
||||
)
|
||||
|
||||
# Mid-turn drain: after ANY tool finishes (MCP or built-in), pull
|
||||
# any queued user follow-up messages and attach them to the
|
||||
# tool_result as ``additionalContext``. This is the
|
||||
# protocol-legal mid-turn injection slot — Claude reads the
|
||||
# follow-up on the next LLM round without starting a new turn.
|
||||
# The drain helper also stashes a persist-queue copy so
|
||||
# ``sdk/service.py`` can append a matching user row to the UI.
|
||||
_, session = get_execution_context()
|
||||
followup = ""
|
||||
if session is not None and session.session_id:
|
||||
followup = await drain_and_format_for_injection(
|
||||
session.session_id,
|
||||
log_prefix="[SDK][PostToolUse]",
|
||||
)
|
||||
if followup:
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
{
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PostToolUse",
|
||||
"additionalContext": followup,
|
||||
}
|
||||
},
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_failure_hook(
|
||||
@@ -365,7 +394,7 @@ def create_security_hooks(
|
||||
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# validates against projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = _sanitize(
|
||||
|
||||
@@ -699,3 +699,160 @@ async def test_subagent_hooks_sanitize_inputs(_subagent_hooks, caplog):
|
||||
assert "\u202a" not in record.message
|
||||
assert "\u200b" not in record.message
|
||||
assert "/tmp/maliciouspath" in caplog.text
|
||||
|
||||
|
||||
# -- PostToolUse: mid-turn pending-message drain ------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_tool_use_injects_followup_additional_context(
|
||||
monkeypatch,
|
||||
):
|
||||
"""Queued messages drain into ``additionalContext`` for any tool."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.copilot import context as ctx_mod
|
||||
from backend.copilot import pending_messages as pm_module
|
||||
|
||||
session = MagicMock()
|
||||
session.session_id = "sess-post-inject"
|
||||
ctx_mod.set_execution_context(
|
||||
user_id="u1",
|
||||
session=session,
|
||||
sandbox=None,
|
||||
sdk_cwd=SDK_CWD,
|
||||
)
|
||||
|
||||
async def fake_drain(_session_id: str):
|
||||
assert _session_id == "sess-post-inject"
|
||||
return [pm_module.PendingMessage(content="please also do X")]
|
||||
|
||||
async def fake_stash(_session_id, _messages):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.pending_messages.drain_pending_messages", fake_drain
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.pending_messages.stash_pending_for_persist", fake_stash
|
||||
)
|
||||
|
||||
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
post = hooks["PostToolUse"][0].hooks[0]
|
||||
|
||||
result = await post(
|
||||
{
|
||||
"tool_name": "WebSearch", # built-in — the path the old wrapper missed
|
||||
"tool_response": "search results here",
|
||||
},
|
||||
tool_use_id="tu-web-1",
|
||||
context={},
|
||||
)
|
||||
|
||||
injected = result.get("hookSpecificOutput", {})
|
||||
assert injected.get("hookEventName") == "PostToolUse"
|
||||
assert "<user_follow_up>" in injected.get("additionalContext", "")
|
||||
assert "please also do X" in injected.get("additionalContext", "")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_tool_use_no_pending_returns_empty(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.copilot import context as ctx_mod
|
||||
|
||||
session = MagicMock()
|
||||
session.session_id = "sess-post-empty"
|
||||
ctx_mod.set_execution_context(
|
||||
user_id="u1", session=session, sandbox=None, sdk_cwd=SDK_CWD
|
||||
)
|
||||
|
||||
async def fake_drain(_session_id: str):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.pending_messages.drain_pending_messages", fake_drain
|
||||
)
|
||||
|
||||
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
post = hooks["PostToolUse"][0].hooks[0]
|
||||
|
||||
result = await post(
|
||||
{"tool_name": "mcp__copilot__run_block", "tool_response": "ok"},
|
||||
tool_use_id="tu-mcp-1",
|
||||
context={},
|
||||
)
|
||||
|
||||
# No additionalContext means Claude gets the tool_result verbatim.
|
||||
assert "hookSpecificOutput" not in result
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_tool_use_drain_failure_returns_empty(monkeypatch):
|
||||
"""A Redis blip must not corrupt the hook response."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.copilot import context as ctx_mod
|
||||
|
||||
session = MagicMock()
|
||||
session.session_id = "sess-post-fail"
|
||||
ctx_mod.set_execution_context(
|
||||
user_id="u1", session=session, sandbox=None, sdk_cwd=SDK_CWD
|
||||
)
|
||||
|
||||
async def failing_drain(_session_id: str):
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.pending_messages.drain_pending_messages", failing_drain
|
||||
)
|
||||
|
||||
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
post = hooks["PostToolUse"][0].hooks[0]
|
||||
|
||||
result = await post(
|
||||
{"tool_name": "Read", "tool_response": "file body"},
|
||||
tool_use_id="tu-read-1",
|
||||
context={},
|
||||
)
|
||||
|
||||
assert "hookSpecificOutput" not in result
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_tool_use_no_session_skips_drain(monkeypatch):
|
||||
from backend.copilot import context as ctx_mod
|
||||
|
||||
ctx_mod.set_execution_context(
|
||||
user_id=None,
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd=SDK_CWD,
|
||||
)
|
||||
|
||||
drain_called = False
|
||||
|
||||
async def fake_drain(_session_id: str):
|
||||
nonlocal drain_called
|
||||
drain_called = True
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.pending_messages.drain_pending_messages", fake_drain
|
||||
)
|
||||
|
||||
hooks = create_security_hooks(user_id=None, sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
post = hooks["PostToolUse"][0].hooks[0]
|
||||
|
||||
result = await post(
|
||||
{"tool_name": "WebSearch", "tool_response": "x"},
|
||||
tool_use_id="tu-x",
|
||||
context={},
|
||||
)
|
||||
|
||||
assert drain_called is False
|
||||
assert "hookSpecificOutput" not in result
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,7 @@ from .service import (
|
||||
_iter_sdk_messages,
|
||||
_normalize_model_name,
|
||||
_reduce_context,
|
||||
_restore_cli_session_for_turn,
|
||||
_TokenUsage,
|
||||
)
|
||||
|
||||
@@ -615,3 +616,340 @@ class TestSdkSessionIdSelection:
|
||||
)
|
||||
assert retry.get("resume") == self.SESSION_ID
|
||||
assert "session_id" not in retry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _restore_cli_session_for_turn — mode check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRestoreCliSessionModeCheck:
|
||||
"""SDK skips --resume when the transcript was written by the baseline mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path):
|
||||
"""A transcript with mode='baseline' must not be used as the --resume source.
|
||||
|
||||
The mode check discards the GCS baseline content and falls back to DB
|
||||
reconstruction from session.messages instead.
|
||||
"""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="hello-unique-marker"),
|
||||
ChatMessage(role="assistant", content="world-unique-marker"),
|
||||
ChatMessage(role="user", content="follow up"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
# Baseline content with a sentinel that must NOT appear in the final transcript
|
||||
baseline_restore = TranscriptDownload(
|
||||
content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n',
|
||||
message_count=1,
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
download_mock = AsyncMock(return_value=baseline_restore)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=download_mock,
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
# download_transcript was called (attempted GCS restore)
|
||||
download_mock.assert_awaited_once()
|
||||
# use_resume must be False — baseline transcripts cannot be used with --resume
|
||||
assert result.use_resume is False
|
||||
# context_messages must be populated — new behaviour uses transcript content + gap
|
||||
# instead of full DB reconstruction.
|
||||
assert result.context_messages is not None
|
||||
# The baseline transcript has 1 user message (BASELINE_SENTINEL).
|
||||
# Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns [].
|
||||
# Result: 1 message from transcript, no gap.
|
||||
assert len(result.context_messages) == 1
|
||||
assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sdk_mode_transcript_allows_resume(self, tmp_path):
|
||||
"""A valid SDK-written transcript is accepted for --resume."""
|
||||
import json as stdlib_json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
lines = [
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "uid-0",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
}
|
||||
),
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "uid-1",
|
||||
"parentUuid": "uid-0",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_1",
|
||||
"model": "test",
|
||||
"type": "message",
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
content = ("\n".join(lines) + "\n").encode("utf-8")
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="hello"),
|
||||
ChatMessage(role="user", content="follow up"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
sdk_restore = TranscriptDownload(
|
||||
content=content,
|
||||
message_count=2,
|
||||
mode="sdk",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=AsyncMock(return_value=sdk_restore),
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
assert result.use_resume is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_mode_context_messages_from_transcript_content(
|
||||
self, tmp_path
|
||||
):
|
||||
"""mode='baseline' → context_messages populated from transcript content + gap.
|
||||
|
||||
When a baseline-mode transcript exists, extract_context_messages converts
|
||||
the JSONL content to ChatMessage objects and returns them in context_messages.
|
||||
use_resume must remain False.
|
||||
"""
|
||||
import json as stdlib_json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
# Build a minimal valid JSONL transcript with 2 messages
|
||||
lines = [
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "uid-0",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "TRANSCRIPT_USER"},
|
||||
}
|
||||
),
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "uid-1",
|
||||
"parentUuid": "uid-0",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_1",
|
||||
"model": "test",
|
||||
"type": "message",
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
content = ("\n".join(lines) + "\n").encode("utf-8")
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="DB_USER"),
|
||||
ChatMessage(role="assistant", content="DB_ASSISTANT"),
|
||||
ChatMessage(role="user", content="current turn"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
baseline_restore = TranscriptDownload(
|
||||
content=content,
|
||||
message_count=2,
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=AsyncMock(return_value=baseline_restore),
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
assert result.use_resume is False
|
||||
assert result.context_messages is not None
|
||||
# Transcript content has 2 messages, no gap (watermark=2, session prior=2)
|
||||
assert len(result.context_messages) == 2
|
||||
assert result.context_messages[0].role == "user"
|
||||
assert result.context_messages[1].role == "assistant"
|
||||
assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "")
|
||||
# transcript_content must be non-empty so the _seed_transcript guard in
|
||||
# stream_chat_completion_sdk skips DB reconstruction (which would duplicate
|
||||
# builder entries since load_previous appends).
|
||||
assert result.transcript_content != ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path):
|
||||
"""mode='baseline' + gap → context_messages includes transcript msgs and gap."""
|
||||
import json as stdlib_json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
# Transcript covers only 2 messages; session has 4 prior + current turn
|
||||
lines = [
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "uid-0",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "TRANSCRIPT_USER_0"},
|
||||
}
|
||||
),
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "uid-1",
|
||||
"parentUuid": "uid-0",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_1",
|
||||
"model": "test",
|
||||
"type": "message",
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
content = ("\n".join(lines) + "\n").encode("utf-8")
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="DB_USER_0"),
|
||||
ChatMessage(role="assistant", content="DB_ASSISTANT_1"),
|
||||
ChatMessage(role="user", content="GAP_USER_2"),
|
||||
ChatMessage(role="assistant", content="GAP_ASSISTANT_3"),
|
||||
ChatMessage(role="user", content="current turn"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
baseline_restore = TranscriptDownload(
|
||||
content=content,
|
||||
message_count=2, # watermark=2; session has 4 prior → gap of 2
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=AsyncMock(return_value=baseline_restore),
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
assert result.use_resume is False
|
||||
assert result.context_messages is not None
|
||||
# 2 from transcript + 2 gap messages = 4 total
|
||||
assert len(result.context_messages) == 4
|
||||
roles = [m.role for m in result.context_messages]
|
||||
assert roles == ["user", "assistant", "user", "assistant"]
|
||||
# Gap messages come from DB (ChatMessage objects)
|
||||
gap_user = result.context_messages[2]
|
||||
gap_asst = result.context_messages[3]
|
||||
assert gap_user.content == "GAP_USER_2"
|
||||
assert gap_asst.content == "GAP_ASSISTANT_3"
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
from .service import (
|
||||
_IDLE_TIMEOUT_SECONDS,
|
||||
_build_system_prompt_value,
|
||||
_is_sdk_disconnect_error,
|
||||
_normalize_model_name,
|
||||
@@ -176,70 +177,18 @@ class TestPromptSupplement:
|
||||
assert "## Tool notes" in local_supplement
|
||||
assert "## Tool notes" in e2b_supplement
|
||||
|
||||
def test_baseline_supplement_includes_tool_docs(self):
|
||||
"""Baseline mode MUST include tool documentation (direct API needs it)."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
def test_baseline_supplement_has_shared_notes_no_tool_list(self):
|
||||
"""Baseline now relies on the OpenAI tools array for schemas and only
|
||||
appends SHARED_TOOL_NOTES (workflow rules not present in any schema).
|
||||
The old auto-generated ``## AVAILABLE TOOLS`` list is gone — it was
|
||||
~4.3K tokens of pure duplication of the tools array."""
|
||||
from backend.copilot.prompting import SHARED_TOOL_NOTES
|
||||
|
||||
supplement = get_baseline_supplement()
|
||||
|
||||
# MUST have tool list section
|
||||
assert "## AVAILABLE TOOLS" in supplement
|
||||
|
||||
# Should NOT have environment-specific notes (SDK-only)
|
||||
assert "## Tool notes" not in supplement
|
||||
|
||||
def test_baseline_supplement_includes_key_tools(self):
|
||||
"""Baseline supplement should document all essential tools."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Core agent workflow tools (always available)
|
||||
assert "`create_agent`" in docs
|
||||
assert "`run_agent`" in docs
|
||||
assert "`find_library_agent`" in docs
|
||||
assert "`edit_agent`" in docs
|
||||
|
||||
# MCP integration (always available)
|
||||
assert "`run_mcp_tool`" in docs
|
||||
|
||||
# Folder management (always available)
|
||||
assert "`create_folder`" in docs
|
||||
|
||||
# Browser tools only if available (Playwright may not be installed in CI)
|
||||
if (
|
||||
TOOL_REGISTRY.get("browser_navigate")
|
||||
and TOOL_REGISTRY["browser_navigate"].is_available
|
||||
):
|
||||
assert "`browser_navigate`" in docs
|
||||
|
||||
def test_baseline_supplement_includes_workflows(self):
|
||||
"""Baseline supplement should include workflow guidance in tool descriptions."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Workflows are now in individual tool descriptions (not separate sections)
|
||||
# Check that key workflow concepts appear in tool descriptions
|
||||
assert "agent_json" in docs or "find_block" in docs
|
||||
assert "run_mcp_tool" in docs
|
||||
|
||||
def test_baseline_supplement_completeness(self):
|
||||
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Verify each available registered tool is documented
|
||||
# (matches _generate_tool_documentation which filters by is_available)
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
assert (
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
assert "## AVAILABLE TOOLS" not in SHARED_TOOL_NOTES
|
||||
# Keep the high-value workflow rules that are NOT in any tool schema.
|
||||
assert "@@agptfile:" in SHARED_TOOL_NOTES
|
||||
assert "Tool Discovery Priority" in SHARED_TOOL_NOTES
|
||||
assert "run_sub_session" in SHARED_TOOL_NOTES
|
||||
|
||||
def test_pause_task_scheduled_before_transcript_upload(self):
|
||||
"""Pause is scheduled as a background task before transcript upload begins.
|
||||
@@ -283,21 +232,6 @@ class TestPromptSupplement:
|
||||
# concurrently during upload's first yield. The ordering guarantee is
|
||||
# that create_task is CALLED before upload is AWAITED (see source order).
|
||||
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Count occurrences of each available tool in the entire supplement
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cleanup_sdk_tool_results — orchestration + rate-limiting
|
||||
@@ -699,6 +633,17 @@ class TestSystemPromptPreset:
|
||||
assert result["append"] == ""
|
||||
assert result["exclude_dynamic_sections"] is True
|
||||
|
||||
def test_resume_and_fresh_share_the_same_static_prefix(self):
|
||||
"""Every turn (fresh + --resume) must emit the same preset dict
|
||||
so the cross-user cache prefix match works on all turns. This
|
||||
relies on CLI ≥ 2.1.98 (installed in the Docker image); older
|
||||
CLIs would crash on --resume + excludeDynamicSections=True."""
|
||||
fresh = _build_system_prompt_value("sys", cross_user_cache=True)
|
||||
resumed = _build_system_prompt_value("sys", cross_user_cache=True)
|
||||
assert fresh == resumed
|
||||
assert isinstance(fresh, dict)
|
||||
assert fresh.get("exclude_dynamic_sections") is True
|
||||
|
||||
def test_default_config_is_enabled(self, _clean_config_env):
|
||||
"""The default value for claude_agent_cross_user_prompt_cache is True."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
@@ -719,3 +664,13 @@ class TestSystemPromptPreset:
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
assert cfg.claude_agent_cross_user_prompt_cache is False
|
||||
|
||||
|
||||
class TestIdleTimeoutConstant:
|
||||
"""SECRT-2247: long-running work now uses async start+poll pattern
|
||||
(run_sub_session / run_agent), so no single MCP tool call ever blocks
|
||||
the stream close to the idle limit. The plain 10-min cap from the
|
||||
original code is restored."""
|
||||
|
||||
def test_idle_timeout_is_10_min(self):
|
||||
assert _IDLE_TIMEOUT_SECONDS == 10 * 60
|
||||
|
||||
@@ -19,9 +19,11 @@ from __future__ import annotations
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.copilot.constants import STOPPED_BY_USER_MARKER
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import StreamStartStep, StreamTextDelta
|
||||
from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator
|
||||
from backend.copilot.session_cleanup import prune_orphan_tool_calls
|
||||
|
||||
_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
@@ -215,3 +217,183 @@ class TestPreCreateAssistantMessage:
|
||||
_simulate_pre_create(acc, ctx)
|
||||
|
||||
assert len(ctx.session.messages) == 0
|
||||
|
||||
|
||||
class TestPruneOrphanToolCalls:
|
||||
"""A Stop mid-tool-call leaves the session ending on an assistant row whose
|
||||
``tool_calls`` have no matching ``role="tool"`` row. Unless pruned before
|
||||
the next turn, the ``--resume`` transcript would hand Claude CLI a
|
||||
``tool_use`` without a paired ``tool_result`` and the SDK would fail.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _tool_call(call_id: str, name: str = "bash_exec") -> dict:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": "{}"},
|
||||
}
|
||||
|
||||
def test_stop_mid_tool_leaves_orphan_assistant(self) -> None:
|
||||
"""Stop between StreamToolInputAvailable and StreamToolOutputAvailable:
|
||||
the assistant row has ``tool_calls`` but no matching tool row."""
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="do something"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[self._tool_call("tc_abc")],
|
||||
),
|
||||
]
|
||||
|
||||
removed = prune_orphan_tool_calls(messages)
|
||||
|
||||
assert removed == 1
|
||||
assert len(messages) == 1
|
||||
assert messages[-1].role == "user"
|
||||
|
||||
def test_stop_strips_stopped_by_user_marker_and_orphan(self) -> None:
|
||||
"""The service also appends a ``STOPPED_BY_USER_MARKER`` after a
|
||||
user stop when the stream loop exits cleanly; both tail rows must go."""
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="do something"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[self._tool_call("tc_abc")],
|
||||
),
|
||||
ChatMessage(role="assistant", content=STOPPED_BY_USER_MARKER),
|
||||
]
|
||||
|
||||
removed = prune_orphan_tool_calls(messages)
|
||||
|
||||
assert removed == 2
|
||||
assert len(messages) == 1
|
||||
assert messages[-1].role == "user"
|
||||
|
||||
def test_completed_tool_call_is_preserved(self) -> None:
|
||||
"""An assistant row whose tool_calls are all resolved is a healthy
|
||||
trailing state and must not be popped."""
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="do something"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[self._tool_call("tc_abc")],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content="ok",
|
||||
tool_call_id="tc_abc",
|
||||
),
|
||||
]
|
||||
|
||||
removed = prune_orphan_tool_calls(messages)
|
||||
|
||||
assert removed == 0
|
||||
assert len(messages) == 3
|
||||
|
||||
def test_partial_resolution_still_pops(self) -> None:
|
||||
"""If an assistant emits multiple tool_calls and only some are
|
||||
resolved, the assistant row is still unsafe for ``--resume``."""
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="do something"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
self._tool_call("tc_1"),
|
||||
self._tool_call("tc_2"),
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content="ok",
|
||||
tool_call_id="tc_1",
|
||||
),
|
||||
]
|
||||
|
||||
removed = prune_orphan_tool_calls(messages)
|
||||
|
||||
# Both the orphan assistant and its partial tool row are dropped.
|
||||
assert removed == 2
|
||||
assert len(messages) == 1
|
||||
assert messages[-1].role == "user"
|
||||
|
||||
def test_plain_assistant_text_preserved(self) -> None:
|
||||
"""A regular text-only assistant tail is healthy and must be kept."""
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="hello"),
|
||||
]
|
||||
|
||||
removed = prune_orphan_tool_calls(messages)
|
||||
|
||||
assert removed == 0
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_session_is_noop(self) -> None:
|
||||
messages: list[ChatMessage] = []
|
||||
assert prune_orphan_tool_calls(messages) == 0
|
||||
|
||||
|
||||
class TestPruneOrphanToolCallsLogging:
|
||||
"""``prune_orphan_tool_calls`` emits an INFO log when the caller passes
|
||||
``log_prefix`` and something was actually popped. Shared by the SDK
|
||||
and baseline turn-start cleanup so both paths log in the same shape."""
|
||||
|
||||
def _tool_call(self, call_id: str) -> dict:
|
||||
return {"id": call_id, "type": "function", "function": {"name": "bash"}}
|
||||
|
||||
def test_logs_when_something_was_pruned(self, caplog) -> None:
|
||||
import backend.copilot.session_cleanup as sc
|
||||
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(
|
||||
role="assistant", content="", tool_calls=[self._tool_call("tc_1")]
|
||||
),
|
||||
]
|
||||
|
||||
sc.logger.propagate = True
|
||||
caplog.set_level("INFO", logger=sc.logger.name)
|
||||
removed = prune_orphan_tool_calls(messages, log_prefix="[TEST] [abc123]")
|
||||
|
||||
assert removed == 1
|
||||
assert any(
|
||||
"[TEST] [abc123]" in r.message and "Dropped 1" in r.message
|
||||
for r in caplog.records
|
||||
), caplog.text
|
||||
|
||||
def test_no_log_when_nothing_to_prune(self, caplog) -> None:
|
||||
import backend.copilot.session_cleanup as sc
|
||||
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="hello"),
|
||||
]
|
||||
|
||||
sc.logger.propagate = True
|
||||
caplog.set_level("INFO", logger=sc.logger.name)
|
||||
removed = prune_orphan_tool_calls(messages, log_prefix="[TEST] [xyz]")
|
||||
|
||||
assert removed == 0
|
||||
assert not any("[TEST] [xyz]" in r.message for r in caplog.records), caplog.text
|
||||
|
||||
def test_no_log_when_log_prefix_is_none(self, caplog) -> None:
|
||||
"""Without ``log_prefix``, ``prune_orphan_tool_calls`` is silent."""
|
||||
import backend.copilot.session_cleanup as sc
|
||||
|
||||
messages: list[ChatMessage] = [
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(
|
||||
role="assistant", content="", tool_calls=[self._tool_call("tc_1")]
|
||||
),
|
||||
]
|
||||
|
||||
sc.logger.propagate = True
|
||||
caplog.set_level("INFO", logger=sc.logger.name)
|
||||
removed = prune_orphan_tool_calls(messages)
|
||||
|
||||
assert removed == 1
|
||||
assert caplog.text == ""
|
||||
|
||||
217
autogpt_platform/backend/backend/copilot/sdk/session_waiter.py
Normal file
217
autogpt_platform/backend/backend/copilot/sdk/session_waiter.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Cross-process helpers: dispatch + await a copilot session turn.
|
||||
|
||||
The sub-AutoPilot tools (``run_sub_session``, ``get_sub_session_result``)
|
||||
and ``AutoPilotBlock`` all delegate a copilot turn to the
|
||||
``copilot_executor`` queue and then wait on the shared
|
||||
``stream_registry`` for the terminal event. This module is the
|
||||
centralised primitive so every caller agrees on the dispatch shape,
|
||||
the event aggregation, and the cleanup contract.
|
||||
|
||||
:func:`wait_for_session_result` accumulates stream events into an
|
||||
:class:`EventAccumulator` so callers get back ``response_text`` /
|
||||
``tool_calls`` / token usage in memory without an extra DB round-trip.
|
||||
|
||||
:func:`run_copilot_turn_via_queue` is the one-shot "create session meta
|
||||
→ enqueue → wait for result" sequence every caller uses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.executor.utils import enqueue_copilot_turn
|
||||
from backend.copilot.pending_message_helpers import (
|
||||
is_turn_in_flight,
|
||||
queue_user_message,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish
|
||||
|
||||
from .stream_accumulator import EventAccumulator, ToolCallEntry, process_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SessionOutcome = Literal["completed", "failed", "running", "queued"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResult:
|
||||
"""Aggregated result from a copilot session turn observed via
|
||||
``stream_registry``.
|
||||
|
||||
When ``queued`` is set, :func:`run_copilot_turn_via_queue` detected an
|
||||
in-flight turn on the target session and pushed the message onto the
|
||||
pending buffer instead of starting a new turn. ``response_text`` is
|
||||
empty and the aggregate counts are zero in that case; the executor
|
||||
running the earlier turn drains the buffer on its next round.
|
||||
"""
|
||||
|
||||
response_text: str = ""
|
||||
tool_calls: list[ToolCallEntry] = field(default_factory=list)
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
queued: bool = False
|
||||
pending_buffer_length: int = 0
|
||||
|
||||
|
||||
async def wait_for_session_result(
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
timeout: float,
|
||||
) -> tuple[SessionOutcome, SessionResult]:
|
||||
"""Drain the session's stream events and aggregate them into a result.
|
||||
|
||||
Returns whatever has been observed at the cap (``running`` + partial
|
||||
result) or at the terminal event (``completed`` / ``failed`` + full
|
||||
result). Cleans up the subscriber listener on every exit path so
|
||||
long-running polls don't leak listeners (sentry r3105348640).
|
||||
"""
|
||||
queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
result = SessionResult()
|
||||
if queue is None:
|
||||
# Session meta not in Redis yet, or the caller doesn't own it.
|
||||
# ``subscribe_to_session`` already retried with backoff before
|
||||
# returning None.
|
||||
return "running", result
|
||||
|
||||
acc = EventAccumulator()
|
||||
outcome: SessionOutcome = "running"
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
deadline = loop.time() + max(timeout, 0)
|
||||
while True:
|
||||
remaining = deadline - loop.time()
|
||||
if remaining <= 0:
|
||||
break
|
||||
event = await asyncio.wait_for(queue.get(), timeout=remaining)
|
||||
process_event(event, acc)
|
||||
if isinstance(event, StreamFinish):
|
||||
outcome = "completed"
|
||||
break
|
||||
if isinstance(event, StreamError):
|
||||
outcome = "failed"
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id=session_id,
|
||||
subscriber_queue=queue,
|
||||
)
|
||||
|
||||
result.response_text = "".join(acc.response_parts)
|
||||
result.tool_calls = list(acc.tool_calls)
|
||||
result.prompt_tokens = acc.prompt_tokens
|
||||
result.completion_tokens = acc.completion_tokens
|
||||
result.total_tokens = acc.total_tokens
|
||||
return outcome, result
|
||||
|
||||
|
||||
async def run_copilot_turn_via_queue(
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
timeout: float,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
) -> tuple[SessionOutcome, SessionResult]:
|
||||
"""Dispatch a copilot turn onto the queue and wait for its result.
|
||||
|
||||
The canonical invocation path shared by ``run_sub_session`` (the
|
||||
copilot tool), ``AutoPilotBlock`` (the graph block), and any future
|
||||
caller that needs to run a copilot turn without occupying its own
|
||||
worker with the SDK stream:
|
||||
|
||||
1. Create a ``stream_registry`` session meta record for the turn.
|
||||
2. Enqueue a ``CoPilotExecutionEntry`` on the copilot_execution
|
||||
exchange. Any idle copilot_executor worker claims it.
|
||||
3. Subscribe to the session's Redis stream and drain events until
|
||||
``StreamFinish`` / ``StreamError`` or the cap fires.
|
||||
|
||||
``tool_call_id`` / ``tool_name`` disambiguate who originated the
|
||||
turn in observability / replay (e.g. ``"sub:<parent>"`` for a
|
||||
sub-session, ``"autopilot_block"`` for an AutoPilotBlock run).
|
||||
|
||||
Self-defensive queue-fallback: if the target session already has a
|
||||
turn running (another ``run_sub_session`` / AutoPilot block / UI
|
||||
chat), don't race it on the cluster lock. Push the message onto the
|
||||
pending buffer so the existing turn drains it at its next round
|
||||
boundary, then:
|
||||
|
||||
* ``timeout == 0`` — return immediately with
|
||||
``("queued", SessionResult(queued=True, ...))``. Callers that
|
||||
explicitly opted into fire-and-forget (``run_sub_session`` with
|
||||
``wait_for_result=0``) use this to bail without waiting.
|
||||
* ``timeout > 0`` — **subscribe to the in-flight turn's stream and
|
||||
return its aggregated result** (exactly the same shape as a
|
||||
normally-dispatched turn, but with ``result.queued=True`` so
|
||||
callers can tell we rode on someone else's turn). Semantically
|
||||
identical to "I asked the session to do something and here is
|
||||
what happened next"; no separate deferred-state branch needed in
|
||||
``run_sub_session`` / ``AutoPilotBlock``.
|
||||
"""
|
||||
if await is_turn_in_flight(session_id):
|
||||
logger.info(
|
||||
"[queue] session=%s has a turn in flight; queueing message "
|
||||
"(tool=%s) into pending buffer instead of starting a new turn",
|
||||
session_id[:12],
|
||||
tool_name,
|
||||
)
|
||||
state = await queue_user_message(session_id=session_id, message=message)
|
||||
if timeout <= 0:
|
||||
# Fire-and-forget: caller explicitly asked not to wait.
|
||||
return "queued", SessionResult(
|
||||
queued=True, pending_buffer_length=state.buffer_length
|
||||
)
|
||||
# Ride the in-flight turn: subscribe to its stream and return the
|
||||
# same aggregated result shape as a fresh dispatch. The model
|
||||
# drains the pending buffer between tool rounds (baseline) or at
|
||||
# the next tool boundary via the PostToolUse hook (SDK), so the
|
||||
# response we observe will reflect our queued follow-up (or be
|
||||
# the terminal result if the in-flight turn finishes before the
|
||||
# buffer drains — in that case ``result.queued=True`` is still
|
||||
# the correct signal for the caller).
|
||||
outcome, observed = await wait_for_session_result(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
)
|
||||
observed.queued = True
|
||||
observed.pending_buffer_length = state.buffer_length
|
||||
return outcome, observed
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
turn_id=turn_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
return await wait_for_session_result(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
)
|
||||
@@ -0,0 +1,169 @@
|
||||
"""Tests for the shared queue primitive in ``session_waiter``.
|
||||
|
||||
Focuses on the queue-on-busy fallback:
|
||||
|
||||
* ``timeout == 0`` — push into the buffer and return immediately with
|
||||
``("queued", SessionResult(queued=True, ...))``; skip registry +
|
||||
RabbitMQ entirely.
|
||||
* ``timeout > 0`` — push into the buffer, then subscribe to the
|
||||
in-flight turn's stream and return its aggregated result (with
|
||||
``queued=True`` annotation) so callers get the same shape as a
|
||||
fresh dispatch.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.session_waiter import SessionResult, run_copilot_turn_via_queue
|
||||
|
||||
_QR = type(
|
||||
"QR",
|
||||
(),
|
||||
{"buffer_length": 4, "max_buffer_length": 10, "turn_in_flight": True},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_branch_timeout_zero_returns_immediately():
|
||||
"""Busy + timeout=0 → no registry, no enqueue, no wait, queued result."""
|
||||
queue_mock = AsyncMock(return_value=_QR())
|
||||
create_session = AsyncMock()
|
||||
enqueue = AsyncMock()
|
||||
wait_result = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.is_turn_in_flight",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.queue_user_message",
|
||||
new=queue_mock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
|
||||
new=create_session,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
|
||||
new=enqueue,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.wait_for_session_result",
|
||||
new=wait_result,
|
||||
),
|
||||
):
|
||||
outcome, result = await run_copilot_turn_via_queue(
|
||||
session_id="sess-busy",
|
||||
user_id="u1",
|
||||
message="follow-up",
|
||||
timeout=0,
|
||||
tool_call_id="sub:parent",
|
||||
tool_name="run_sub_session",
|
||||
)
|
||||
|
||||
assert outcome == "queued"
|
||||
assert isinstance(result, SessionResult)
|
||||
assert result.queued is True
|
||||
assert result.pending_buffer_length == 4
|
||||
create_session.assert_not_awaited()
|
||||
enqueue.assert_not_awaited()
|
||||
wait_result.assert_not_awaited()
|
||||
queue_mock.assert_awaited_once_with(session_id="sess-busy", message="follow-up")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_branch_positive_timeout_rides_inflight_turn():
|
||||
"""Busy + timeout>0 → push buffer, subscribe to in-flight turn, return
|
||||
its aggregated result with ``queued=True`` annotation."""
|
||||
queue_mock = AsyncMock(return_value=_QR())
|
||||
create_session = AsyncMock()
|
||||
enqueue = AsyncMock()
|
||||
observed = SessionResult()
|
||||
observed.response_text = "final answer from in-flight turn"
|
||||
wait_result = AsyncMock(return_value=("completed", observed))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.is_turn_in_flight",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.queue_user_message",
|
||||
new=queue_mock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
|
||||
new=create_session,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
|
||||
new=enqueue,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.wait_for_session_result",
|
||||
new=wait_result,
|
||||
),
|
||||
):
|
||||
outcome, result = await run_copilot_turn_via_queue(
|
||||
session_id="sess-busy",
|
||||
user_id="u1",
|
||||
message="follow-up",
|
||||
timeout=30.0,
|
||||
tool_call_id="autopilot_block",
|
||||
tool_name="autopilot_block",
|
||||
)
|
||||
|
||||
# We rode on the existing turn — its outcome + aggregate propagate up.
|
||||
assert outcome == "completed"
|
||||
assert result.response_text == "final answer from in-flight turn"
|
||||
# Marker so callers can tell we didn't start a fresh turn.
|
||||
assert result.queued is True
|
||||
assert result.pending_buffer_length == 4
|
||||
# Still no new registry entry / no new RabbitMQ job — that was the point.
|
||||
create_session.assert_not_awaited()
|
||||
enqueue.assert_not_awaited()
|
||||
# Subscribed to the session stream (not a new turn_id).
|
||||
wait_result.assert_awaited_once()
|
||||
assert wait_result.await_args.kwargs["session_id"] == "sess-busy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idle_session_enqueues_normally():
|
||||
"""Idle session → registry session created, enqueued, drain waits."""
|
||||
create_session = AsyncMock()
|
||||
enqueue = AsyncMock()
|
||||
wait_result = AsyncMock(return_value=("completed", SessionResult()))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.is_turn_in_flight",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
|
||||
new=create_session,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
|
||||
new=enqueue,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.session_waiter.wait_for_session_result",
|
||||
new=wait_result,
|
||||
),
|
||||
):
|
||||
outcome, result = await run_copilot_turn_via_queue(
|
||||
session_id="sess-idle",
|
||||
user_id="u1",
|
||||
message="kick off",
|
||||
timeout=0.1,
|
||||
tool_call_id="autopilot_block",
|
||||
tool_name="autopilot_block",
|
||||
)
|
||||
|
||||
assert outcome == "completed"
|
||||
assert result.queued is False
|
||||
create_session.assert_awaited_once()
|
||||
enqueue.assert_awaited_once()
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Stream event → aggregated result accumulator.
|
||||
|
||||
Consumes the same ``StreamBaseResponse`` events that fly over
|
||||
``stream_registry`` (text deltas, tool i/o, usage, errors) and folds
|
||||
them into a single :class:`EventAccumulator` state. Used by
|
||||
:func:`session_waiter.wait_for_session_result` to read events from a
|
||||
Redis Stream subscription so a different process can obtain the
|
||||
aggregated result for a session it didn't run.
|
||||
|
||||
Keeping the dispatch in one place means new event types can be added
|
||||
without drifting callers apart on what "response_text", "tool_calls",
|
||||
or token counts mean.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..response_model import (
|
||||
StreamError,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolCallEntry(BaseModel):
|
||||
"""A single tool call observed during stream consumption."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any = None
|
||||
success: bool | None = None
|
||||
|
||||
|
||||
class EventAccumulator(BaseModel):
|
||||
"""Mutable accumulator fed by :func:`process_event`."""
|
||||
|
||||
response_parts: list[str] = Field(default_factory=list)
|
||||
tool_calls: list[ToolCallEntry] = Field(default_factory=list)
|
||||
tool_calls_by_id: dict[str, ToolCallEntry] = Field(default_factory=dict)
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
def process_event(event: object, acc: EventAccumulator) -> str | None:
|
||||
"""Fold *event* into *acc*. Returns the error text on ``StreamError``.
|
||||
|
||||
Uses structural pattern matching for dispatch per project guidelines.
|
||||
"""
|
||||
match event:
|
||||
case StreamTextDelta(delta=delta):
|
||||
acc.response_parts.append(delta)
|
||||
case StreamToolInputAvailable() as e:
|
||||
entry = ToolCallEntry(
|
||||
tool_call_id=e.toolCallId,
|
||||
tool_name=e.toolName,
|
||||
input=e.input,
|
||||
)
|
||||
acc.tool_calls.append(entry)
|
||||
acc.tool_calls_by_id[e.toolCallId] = entry
|
||||
case StreamToolOutputAvailable() as e:
|
||||
if tc := acc.tool_calls_by_id.get(e.toolCallId):
|
||||
tc.output = e.output
|
||||
tc.success = e.success
|
||||
else:
|
||||
logger.debug(
|
||||
"Received tool output for unknown tool_call_id: %s",
|
||||
e.toolCallId,
|
||||
)
|
||||
case StreamUsage() as e:
|
||||
acc.prompt_tokens += e.prompt_tokens
|
||||
acc.completion_tokens += e.completion_tokens
|
||||
acc.total_tokens += e.total_tokens
|
||||
case StreamError(errorText=err):
|
||||
return err
|
||||
return None
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk.
|
||||
|
||||
The fix is at the upload step: when use_resume=True and transcript_msg_count>0
|
||||
we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just
|
||||
recorded) instead of len(session.messages). This prevents the "inflated
|
||||
watermark" bug where a stale JSONL in GCS could hide missing context from
|
||||
future gap-fill checks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _compute_jsonl_covered(
|
||||
use_resume: bool,
|
||||
transcript_msg_count: int,
|
||||
session_msg_count: int,
|
||||
) -> int:
|
||||
"""Mirror the watermark computation from ``stream_chat_completion_sdk``.
|
||||
|
||||
Extracted here so we can unit-test it independently without invoking the
|
||||
full streaming stack.
|
||||
"""
|
||||
if use_resume and transcript_msg_count > 0:
|
||||
return transcript_msg_count + 2
|
||||
return session_msg_count
|
||||
|
||||
|
||||
class TestWatermarkFix:
|
||||
"""Watermark computation logic — mirrors the finally-block in SDK service."""
|
||||
|
||||
def test_inflated_watermark_triggers_gap_fill(self):
|
||||
"""Stale JSONL (T12) with high watermark (46) → after fix, watermark=14.
|
||||
|
||||
Before fix: watermark=46 → next turn's gap check (transcript_msg_count < db-1)
|
||||
never fires because 46 >= 47-1=46, so context loss is silent.
|
||||
After fix: watermark = 12 + 2 = 14 → gap check fires (14 < 46) and
|
||||
the model receives the missing turns.
|
||||
"""
|
||||
# Simulate: use_resume=True, transcript covered T12 (12 msgs), DB now has 47
|
||||
use_resume = True
|
||||
transcript_msg_count = 12
|
||||
session_msg_count = 47 # DB count (what old code used to set watermark)
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == 14 # 12 + 2, NOT 47
|
||||
# Verify: the gap check would fire on next turn
|
||||
# next-turn check: transcript_msg_count < msg_count - 1 → 14 < 47-1=46 → True
|
||||
assert watermark < session_msg_count - 1
|
||||
|
||||
def test_no_false_positive_when_transcript_current(self):
|
||||
"""Transcript current (watermark=46, DB=47) → gap stays 0.
|
||||
|
||||
When the JSONL actually covers T46 (the most recent assistant turn),
|
||||
uploading watermark=46+2=48 means next turn's gap check sees
|
||||
48 >= 48-1=47 → no gap. Correct.
|
||||
"""
|
||||
use_resume = True
|
||||
transcript_msg_count = 46
|
||||
session_msg_count = 47
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == 48 # 46 + 2
|
||||
# Next turn: session has 48 msgs, watermark=48 → 48 >= 48-1=47 → no gap
|
||||
next_turn_session = 48
|
||||
assert watermark >= next_turn_session - 1
|
||||
|
||||
def test_fresh_session_falls_back_to_db_count(self):
|
||||
"""use_resume=False → watermark = len(session.messages) (original behaviour)."""
|
||||
use_resume = False
|
||||
transcript_msg_count = 0
|
||||
session_msg_count = 3
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == session_msg_count
|
||||
|
||||
def test_old_format_meta_zero_count_falls_back_to_db(self):
|
||||
"""transcript_msg_count=0 (old-format meta with no count field) → DB fallback."""
|
||||
use_resume = True
|
||||
transcript_msg_count = 0 # old-format meta or not-yet-set
|
||||
session_msg_count = 10
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == session_msg_count
|
||||
@@ -62,11 +62,24 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max MCP response size in chars. 100K chars ≈ 25K tokens. The SDK writes oversized results to tool-results/ files.
|
||||
# Set to 100K (down from a previous 500K) because the SDK already reads back large results from disk via
|
||||
# tool-results/ — sending 500K chars inline bloated the context window and caused cache-miss thrashing.
|
||||
# 100K keeps the common case (block output, API responses) in-band without punishing the context budget.
|
||||
_MCP_MAX_CHARS = 100_000
|
||||
# Max MCP response size in chars — sized to the Claude CLI's internal cap.
|
||||
#
|
||||
# The CLI has a default ``maxResultSizeChars = 1e5`` (100K chars) annotation
|
||||
# for MCP tool results, but the actual trigger is TOKEN-based (see
|
||||
# ``sizeEstimateTokens`` in the bundled CLI at ``tengu_mcp_large_result_handled``)
|
||||
# and fires around 20–25K tokens. For JSON-heavy tool output (~3–4 chars/token)
|
||||
# that lands anywhere from ~60K to ~100K chars in practice; we've observed the
|
||||
# error path at 81K chars in production. When it fires, the CLI persists the
|
||||
# full output to disk and REPLACES the returned content with a synthetic
|
||||
# ``"Error: result (N characters) exceeds maximum allowed tokens. Output has
|
||||
# been saved to …"`` message — which destroys any `<user_follow_up>` block
|
||||
# we injected.
|
||||
#
|
||||
# 70K gives us headroom below the observed 81K trigger and leaves ~6K for the
|
||||
# follow-up injection plus CLI wire overhead. Oversized content is still
|
||||
# reachable via ``read_tool_result`` against the persisted disk file; only
|
||||
# the inline reply to this specific call is truncated.
|
||||
_MCP_MAX_CHARS = 70_000
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
@@ -248,7 +261,14 @@ async def _execute_tool_sync(
|
||||
session: ChatSession,
|
||||
args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
||||
"""Execute a tool inline and return an MCP-formatted response.
|
||||
|
||||
The call runs to completion — no per-handler timeout, no parking. The
|
||||
stream-level idle timer in ``_run_stream_attempt`` pauses while a tool
|
||||
is pending, so a long sub-AutoPilot / graph execution doesn't trip the
|
||||
30-min idle safety net (SECRT-2247). A genuine hang is handled by the
|
||||
broader session lifecycle (user closes the tab / cancel endpoint).
|
||||
"""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
@@ -612,8 +632,12 @@ def _make_truncating_wrapper(
|
||||
else:
|
||||
_clear_tool_failures(tool_name)
|
||||
|
||||
# Stash BEFORE stripping so the frontend SSE stream receives
|
||||
# the full output including _STRIP_FROM_LLM fields (e.g. is_dry_run).
|
||||
# Stash the raw tool output for the frontend SSE stream so widgets
|
||||
# (bash, tool viewers) receive clean JSON. Mid-turn user follow-up
|
||||
# injection for MCP + built-in tools is now handled uniformly by
|
||||
# the ``PostToolUse`` hook via ``additionalContext`` so Claude sees
|
||||
# the follow-up attached to the tool_result without mutating the
|
||||
# frontend-facing payload.
|
||||
if not truncated.get("isError"):
|
||||
text = _text_from_mcp_result(truncated)
|
||||
if text:
|
||||
|
||||
@@ -251,7 +251,10 @@ class TestTruncationAndStashIntegration:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_tool(name: str, output: str = "result") -> MagicMock:
|
||||
def _make_mock_tool(
|
||||
name: str,
|
||||
output: str = "result",
|
||||
) -> MagicMock:
|
||||
"""Return a BaseTool mock that returns a successful StreamToolOutputAvailable."""
|
||||
tool = MagicMock()
|
||||
tool.name = name
|
||||
@@ -336,6 +339,38 @@ class TestCreateToolHandler:
|
||||
assert mock_tool.execute.await_count == 2
|
||||
|
||||
|
||||
class TestToolInlineExecution:
|
||||
"""Tools run inline to completion — no per-handler timeout, no parking."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init(self):
|
||||
_init_ctx(session=_make_mock_session())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_runs_to_completion_regardless_of_duration(self):
|
||||
"""A tool that takes a while still runs inline; the handler does not
|
||||
park, cancel, or wrap it in a timeout. The stream-level idle timer
|
||||
(in _run_stream_attempt) is what pauses while tool calls are pending."""
|
||||
|
||||
async def slow_but_completes(*_args, **_kwargs):
|
||||
await asyncio.sleep(0.1)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId="t1",
|
||||
output="final-result",
|
||||
toolName="slow_tool",
|
||||
success=True,
|
||||
)
|
||||
|
||||
mock_tool = _make_mock_tool("slow_tool")
|
||||
mock_tool.execute = AsyncMock(side_effect=slow_but_completes)
|
||||
|
||||
handler = create_tool_handler(mock_tool)
|
||||
result = await handler({})
|
||||
|
||||
assert result["isError"] is False
|
||||
assert "final-result" in result["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression tests: bugs fixed by removing pre-launch mechanism
|
||||
#
|
||||
@@ -873,7 +908,9 @@ class TestStripLlmFields:
|
||||
"""
|
||||
dry_run_session = MagicMock()
|
||||
dry_run_session.dry_run = True
|
||||
set_execution_context(user_id="test", session=dry_run_session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
|
||||
set_execution_context(
|
||||
user_id="test", session=dry_run_session, sandbox=None, sdk_cwd="/tmp/test"
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
full_payload = '{"message": "done", "is_dry_run": true}'
|
||||
|
||||
@@ -906,7 +943,9 @@ class TestStripLlmFields:
|
||||
"""
|
||||
normal_session = MagicMock()
|
||||
normal_session.dry_run = False
|
||||
set_execution_context(user_id="test", session=normal_session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
|
||||
set_execution_context(
|
||||
user_id="test", session=normal_session, sandbox=None, sdk_cwd="/tmp/test"
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
full_payload = '{"message": "simulated", "is_dry_run": true}'
|
||||
|
||||
@@ -929,3 +968,53 @@ class TestStripLlmFields:
|
||||
stashed = pop_pending_tool_output("fake_tool_normal")
|
||||
assert stashed is not None
|
||||
assert '"is_dry_run": true' in stashed
|
||||
|
||||
|
||||
class TestTruncatingWrapperLeavesOutputUntouched:
|
||||
"""Mid-turn drain moved to the shared ``PostToolUse`` hook path so every
|
||||
tool (MCP + built-in) is covered uniformly. The wrapper must therefore
|
||||
forward tool output verbatim and never touch ``<user_follow_up>``."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrapper_does_not_inject_followup(self):
|
||||
session = MagicMock()
|
||||
session.dry_run = False
|
||||
session.session_id = "sess-no-inject"
|
||||
set_execution_context(user_id="u", session=session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
|
||||
|
||||
async def fake_tool_fn(_args: dict) -> dict:
|
||||
return {
|
||||
"content": [{"type": "text", "text": "CLEAN_OUTPUT"}],
|
||||
"isError": False,
|
||||
}
|
||||
|
||||
wrapper = _make_truncating_wrapper(fake_tool_fn, "fake_tool_clean")
|
||||
result = await wrapper({})
|
||||
|
||||
text = result["content"][0]["text"]
|
||||
assert text == "CLEAN_OUTPUT"
|
||||
assert "<user_follow_up>" not in text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_stays_clean(self):
|
||||
"""The frontend-facing stash must be a byte-for-byte copy of the
|
||||
raw tool output (needed for JSON.parse in the bash widget)."""
|
||||
session = MagicMock()
|
||||
session.dry_run = False
|
||||
session.session_id = "sess-stash"
|
||||
set_execution_context(user_id="u", session=session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
|
||||
|
||||
clean_json = '{"stdout": "hello\\n", "exit_code": 0}'
|
||||
|
||||
async def fake_tool_fn(_args: dict) -> dict:
|
||||
return {
|
||||
"content": [{"type": "text", "text": clean_json}],
|
||||
"isError": False,
|
||||
}
|
||||
|
||||
wrapper = _make_truncating_wrapper(fake_tool_fn, "fake_tool_stash_pure")
|
||||
await wrapper({})
|
||||
|
||||
stashed = pop_pending_tool_output("fake_tool_stash_pure")
|
||||
assert stashed == clean_json
|
||||
assert "<user_follow_up>" not in (stashed or "")
|
||||
|
||||
@@ -12,18 +12,20 @@ from backend.copilot.transcript import (
|
||||
ENTRY_TYPE_MESSAGE,
|
||||
STOP_REASON_END_TURN,
|
||||
STRIPPABLE_TYPES,
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
TranscriptDownload,
|
||||
TranscriptMode,
|
||||
cleanup_stale_project_dirs,
|
||||
cli_session_path,
|
||||
compact_transcript,
|
||||
delete_transcript,
|
||||
detect_gap,
|
||||
download_transcript,
|
||||
extract_context_messages,
|
||||
projects_base,
|
||||
read_compacted_entries,
|
||||
restore_cli_session,
|
||||
strip_for_upload,
|
||||
strip_progress_entries,
|
||||
strip_stale_thinking_blocks,
|
||||
upload_cli_session,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -34,18 +36,20 @@ __all__ = [
|
||||
"ENTRY_TYPE_MESSAGE",
|
||||
"STOP_REASON_END_TURN",
|
||||
"STRIPPABLE_TYPES",
|
||||
"TRANSCRIPT_STORAGE_PREFIX",
|
||||
"TranscriptDownload",
|
||||
"TranscriptMode",
|
||||
"cleanup_stale_project_dirs",
|
||||
"cli_session_path",
|
||||
"compact_transcript",
|
||||
"delete_transcript",
|
||||
"detect_gap",
|
||||
"download_transcript",
|
||||
"extract_context_messages",
|
||||
"projects_base",
|
||||
"read_compacted_entries",
|
||||
"restore_cli_session",
|
||||
"strip_for_upload",
|
||||
"strip_progress_entries",
|
||||
"strip_stale_thinking_blocks",
|
||||
"upload_cli_session",
|
||||
"upload_transcript",
|
||||
"validate_transcript",
|
||||
"write_transcript_to_tempfile",
|
||||
|
||||
@@ -297,8 +297,8 @@ class TestStripProgressEntries:
|
||||
|
||||
class TestDeleteTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_both_jsonl_and_meta(self):
|
||||
"""delete_transcript removes both the .jsonl and .meta.json files."""
|
||||
async def test_deletes_cli_session_and_meta(self):
|
||||
"""delete_transcript removes the CLI session .jsonl and .meta.json."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
@@ -309,7 +309,7 @@ class TestDeleteTranscript:
|
||||
):
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 3
|
||||
assert mock_storage.delete.call_count == 2
|
||||
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert any(p.endswith(".jsonl") for p in paths)
|
||||
assert any(p.endswith(".meta.json") for p in paths)
|
||||
@@ -319,7 +319,7 @@ class TestDeleteTranscript:
|
||||
"""If .jsonl delete fails, .meta.json delete is still attempted."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[Exception("jsonl delete failed"), None, None]
|
||||
side_effect=[Exception("jsonl delete failed"), None]
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -330,14 +330,14 @@ class TestDeleteTranscript:
|
||||
# Should not raise
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 3
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_meta_delete_failure(self):
|
||||
"""If .meta.json delete fails, no exception propagates."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[None, Exception("meta delete failed"), None]
|
||||
side_effect=[None, Exception("meta delete failed")]
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks:
|
||||
# Both entries of last turn (msg_last) preserved
|
||||
assert lines[1]["message"]["content"][0]["type"] == "thinking"
|
||||
assert lines[2]["message"]["content"][0]["type"] == "text"
|
||||
|
||||
|
||||
class TestProcessCliRestore:
|
||||
"""``process_cli_restore`` validates, strips, and writes CLI session to disk."""
|
||||
|
||||
def test_writes_stripped_bytes_not_raw(self, tmp_path):
|
||||
"""Stripped bytes (not raw bytes) must be written to disk for --resume."""
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import process_cli_restore
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
|
||||
session_id = "12345678-0000-0000-0000-abcdef000001"
|
||||
sdk_cwd = str(tmp_path)
|
||||
projects_base_dir = str(tmp_path)
|
||||
|
||||
# Build raw content with a strippable progress entry + a valid user/assistant pair
|
||||
raw_content = (
|
||||
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
|
||||
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
|
||||
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
|
||||
)
|
||||
raw_bytes = raw_content.encode("utf-8")
|
||||
restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
stripped_str, ok = process_cli_restore(
|
||||
restore, sdk_cwd, session_id, "[Test]"
|
||||
)
|
||||
|
||||
assert ok, "Expected successful restore"
|
||||
|
||||
# Find the written session file
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl"
|
||||
assert session_file.exists(), "Session file should have been written"
|
||||
|
||||
written_bytes = session_file.read_bytes()
|
||||
# The written bytes must be the stripped version (no progress entry)
|
||||
assert (
|
||||
b"progress" not in written_bytes
|
||||
), "Raw bytes with progress entry should not have been written"
|
||||
assert (
|
||||
b"hello" in written_bytes
|
||||
), "Stripped content should still contain assistant turn"
|
||||
|
||||
# Written bytes must equal the stripped string re-encoded
|
||||
assert written_bytes == stripped_str.encode(
|
||||
"utf-8"
|
||||
), "Written bytes must equal stripped content"
|
||||
|
||||
def test_invalid_content_returns_false(self):
|
||||
"""Content that fails validation after strip returns (empty, False)."""
|
||||
from backend.copilot.sdk.service import process_cli_restore
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
|
||||
# A single progress-only entry — stripped result will be empty/invalid
|
||||
raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
|
||||
restore = TranscriptDownload(
|
||||
content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
|
||||
)
|
||||
|
||||
stripped_str, ok = process_cli_restore(
|
||||
restore,
|
||||
"/tmp/nonexistent-sdk-cwd",
|
||||
"12345678-0000-0000-0000-000000000099",
|
||||
"[Test]",
|
||||
)
|
||||
|
||||
assert not ok
|
||||
assert stripped_str == ""
|
||||
|
||||
|
||||
class TestReadCliSessionFromDisk:
|
||||
"""``read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
|
||||
|
||||
def _build_session_file(self, tmp_path, session_id: str):
|
||||
"""Build the session file path inside tmp_path using the same encoding as cli_session_path."""
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
sdk_cwd = str(tmp_path)
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
session_dir = Path(str(tmp_path)) / encoded_cwd
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
return sdk_cwd, session_dir / f"{session_id}.jsonl"
|
||||
|
||||
def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path):
|
||||
"""Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import read_cli_session_from_disk
|
||||
|
||||
session_id = "12345678-0000-0000-0000-aabbccdd0001"
|
||||
projects_base_dir = str(tmp_path)
|
||||
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
|
||||
|
||||
# Write raw invalid UTF-8 bytes
|
||||
session_file.write_bytes(b"\xff\xfe invalid utf-8\n")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
|
||||
# UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
|
||||
assert result == b"\xff\xfe invalid utf-8\n"
|
||||
|
||||
def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path):
|
||||
"""OSError on write-back returns stripped bytes for GCS upload (not raw)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import read_cli_session_from_disk
|
||||
|
||||
session_id = "12345678-0000-0000-0000-aabbccdd0002"
|
||||
projects_base_dir = str(tmp_path)
|
||||
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
|
||||
|
||||
# Content with a strippable progress entry so stripped_bytes < raw_bytes
|
||||
raw_content = (
|
||||
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
|
||||
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
|
||||
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
|
||||
)
|
||||
session_file.write_bytes(raw_content.encode("utf-8"))
|
||||
# Make the file read-only so write_bytes raises OSError on the write-back
|
||||
session_file.chmod(0o444)
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
finally:
|
||||
session_file.chmod(0o644)
|
||||
|
||||
# Must return stripped bytes (not raw, not None) so GCS gets the clean version
|
||||
assert result is not None
|
||||
assert (
|
||||
b"progress" not in result
|
||||
), "Stripped bytes must not contain progress entry"
|
||||
assert b"hello" in result, "Stripped bytes should contain assistant turn"
|
||||
|
||||
@@ -26,7 +26,7 @@ from backend.data.understanding import (
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from .config import ChatConfig
|
||||
from .config import ChatConfig, CopilotLlmModel
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSessionInfo,
|
||||
@@ -40,6 +40,21 @@ logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def resolve_chat_model(tier: CopilotLlmModel | None) -> str:
|
||||
"""Return the configured OpenRouter model string for the given tier.
|
||||
|
||||
Shared by the baseline (fast) and SDK (extended thinking) paths so
|
||||
both honor the same standard/advanced env-var configuration. ``None``
|
||||
and ``'standard'`` fall through to ``config.model``; ``'advanced'``
|
||||
uses ``config.advanced_model``. Keep this flat — if a third tier
|
||||
shows up later, extend here and both paths pick it up for free.
|
||||
"""
|
||||
if tier == "advanced":
|
||||
return config.advanced_model
|
||||
return config.model
|
||||
|
||||
|
||||
_client: LangfuseAsyncOpenAI | None = None
|
||||
_langfuse = None
|
||||
|
||||
@@ -446,7 +461,9 @@ async def inject_user_context(
|
||||
+ final_message
|
||||
)
|
||||
|
||||
for session_msg in session_messages:
|
||||
# Scan in reverse so we target the current turn's user message, not
|
||||
# an older one that may exist when pending messages have been drained.
|
||||
for session_msg in reversed(session_messages):
|
||||
if session_msg.role == "user":
|
||||
# Only touch the DB / in-memory state when the content actually
|
||||
# needs to change — avoids an unnecessary write on the common
|
||||
|
||||
@@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
# (CLI version, platform). When that happens, multi-turn still works
|
||||
# via conversation compression (non-resume path), but we can't test
|
||||
# the --resume round-trip.
|
||||
transcript = None
|
||||
cli_session = None
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0.5)
|
||||
transcript = await download_transcript(test_user_id, session.session_id)
|
||||
if transcript:
|
||||
cli_session = await download_transcript(test_user_id, session.session_id)
|
||||
# Wait until both the session bytes AND the message_count watermark are
|
||||
# present — a session with message_count=0 means the .meta.json hasn't
|
||||
# been uploaded yet, so --resume on the next turn would skip gap-fill.
|
||||
if cli_session and cli_session.message_count > 0:
|
||||
break
|
||||
if not transcript:
|
||||
if not cli_session:
|
||||
return pytest.skip(
|
||||
"CLI did not produce a usable transcript — "
|
||||
"cannot test --resume round-trip in this environment"
|
||||
)
|
||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
|
||||
logger.info(
|
||||
f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}"
|
||||
)
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
|
||||
77
autogpt_platform/backend/backend/copilot/session_cleanup.py
Normal file
77
autogpt_platform/backend/backend/copilot/session_cleanup.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Pre-turn cleanup of transient markers left on ``session.messages`` by
|
||||
prior turns (user-initiated Stop, cancelled tool calls, etc.).
|
||||
|
||||
Shared by both the SDK and baseline chat entry points so both code paths
|
||||
start every new turn from a well-formed message list.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from backend.copilot.constants import STOPPED_BY_USER_MARKER
|
||||
from backend.copilot.model import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prune_orphan_tool_calls(
|
||||
messages: list[ChatMessage],
|
||||
log_prefix: str | None = None,
|
||||
) -> int:
|
||||
"""Pop trailing orphan tool-use blocks from *messages* in place.
|
||||
|
||||
A Stop mid-tool-call leaves the session ending on an assistant message
|
||||
whose ``tool_calls`` have no matching ``role="tool"`` row — the tool
|
||||
never produced output because the executor was cancelled. Feeding that
|
||||
tail to the next ``--resume`` turn would hand the Claude CLI a
|
||||
``tool_use`` with no paired ``tool_result`` and the SDK raises a
|
||||
generic error.
|
||||
|
||||
Also strips trailing ``STOPPED_BY_USER_MARKER`` assistant rows emitted
|
||||
by the same Stop path so the next turn's transcript starts clean.
|
||||
|
||||
If *log_prefix* is given, emits an INFO log with the prefix whenever
|
||||
something was actually popped so the turn-start cleanup is visible.
|
||||
|
||||
In-memory only — the DB write path is append-only via
|
||||
``start_sequence`` so no delete is needed; the same rows are popped
|
||||
again on the next session load.
|
||||
"""
|
||||
cut_index: int | None = None
|
||||
resolved_ids: set[str] = set()
|
||||
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[i]
|
||||
|
||||
if msg.role == "tool" and msg.tool_call_id:
|
||||
resolved_ids.add(msg.tool_call_id)
|
||||
continue
|
||||
|
||||
if msg.role == "assistant" and msg.content == STOPPED_BY_USER_MARKER:
|
||||
cut_index = i
|
||||
continue
|
||||
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
pending_ids = {
|
||||
tc.get("id")
|
||||
for tc in msg.tool_calls
|
||||
if isinstance(tc, dict) and tc.get("id")
|
||||
}
|
||||
if pending_ids and not pending_ids.issubset(resolved_ids):
|
||||
cut_index = i
|
||||
break
|
||||
|
||||
break
|
||||
|
||||
if cut_index is None:
|
||||
return 0
|
||||
|
||||
removed = len(messages) - cut_index
|
||||
del messages[cut_index:]
|
||||
if log_prefix:
|
||||
logger.info(
|
||||
"%s Dropped %d trailing orphan tool-use/stop row(s) "
|
||||
"before starting new turn",
|
||||
log_prefix,
|
||||
removed,
|
||||
)
|
||||
return removed
|
||||
@@ -17,7 +17,7 @@ Subscribers:
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
@@ -32,9 +32,10 @@ from backend.data.notification_bus import (
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.redis_helpers import hash_compare_and_set
|
||||
|
||||
from .config import ChatConfig
|
||||
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS, get_session_lock_key
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamBaseResponse,
|
||||
@@ -42,6 +43,9 @@ from .response_model import (
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamReasoningDelta,
|
||||
StreamReasoningEnd,
|
||||
StreamReasoningStart,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
@@ -68,17 +72,6 @@ _listener_sessions: dict[int, tuple[str, asyncio.Task]] = {}
|
||||
# If the queue is full and doesn't drain within this time, send an overflow error
|
||||
QUEUE_PUT_TIMEOUT = 5.0
|
||||
|
||||
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
||||
# Returns 1 if status was updated, 0 if already completed/failed
|
||||
COMPLETE_SESSION_SCRIPT = """
|
||||
local current = redis.call("HGET", KEYS[1], "status")
|
||||
if current == "running" then
|
||||
redis.call("HSET", KEYS[1], "status", ARGV[1])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveSession:
|
||||
@@ -336,8 +329,8 @@ async def publish_chunk(
|
||||
async def stream_and_publish(
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
stream: AsyncIterator[StreamBaseResponse],
|
||||
) -> AsyncIterator[StreamBaseResponse]:
|
||||
stream: AsyncGenerator[StreamBaseResponse, None],
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Wrap an async stream iterator with registry publishing.
|
||||
|
||||
Publishes each chunk to the stream registry for frontend SSE consumption,
|
||||
@@ -360,27 +353,35 @@ async def stream_and_publish(
|
||||
"""
|
||||
publish_failed_once = False
|
||||
|
||||
async for event in stream:
|
||||
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
|
||||
try:
|
||||
await publish_chunk(turn_id, event, session_id=session_id)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
if not publish_failed_once:
|
||||
publish_failed_once = True
|
||||
logger.warning(
|
||||
"[stream_and_publish] Failed to publish chunk %s for %s "
|
||||
"(further failures logged at DEBUG)",
|
||||
type(event).__name__,
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[stream_and_publish] Failed to publish chunk %s",
|
||||
type(event).__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
yield event
|
||||
# async-for does not close an iterator on GeneratorExit; forward close
|
||||
# to ``stream`` explicitly so its own cleanup (stream lock, persist)
|
||||
# runs deterministically instead of waiting for GC.
|
||||
try:
|
||||
async for event in stream:
|
||||
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
|
||||
try:
|
||||
await publish_chunk(turn_id, event, session_id=session_id)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
# Full stack trace on the first failure; terser lines
|
||||
# for the rest so subsequent failures don't flood logs
|
||||
# while still being visible at WARNING.
|
||||
if not publish_failed_once:
|
||||
publish_failed_once = True
|
||||
logger.warning(
|
||||
"[stream_and_publish] Failed to publish chunk %s for %s",
|
||||
type(event).__name__,
|
||||
session_id[:12],
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[stream_and_publish] Failed to publish chunk %s for %s",
|
||||
type(event).__name__,
|
||||
session_id[:12],
|
||||
)
|
||||
yield event
|
||||
finally:
|
||||
await stream.aclose()
|
||||
|
||||
|
||||
async def subscribe_to_session(
|
||||
@@ -423,20 +424,33 @@ async def subscribe_to_session(
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||
)
|
||||
|
||||
# RACE CONDITION FIX: If session not found, retry once after small delay
|
||||
# This handles the case where subscribe_to_session is called immediately
|
||||
# after create_session but before Redis propagates the write
|
||||
# RACE CONDITION FIX: If session not found, retry with backoff.
|
||||
# Duplicate requests skip create_session and subscribe immediately; the
|
||||
# original request's create_session (a Redis hset) may not have completed
|
||||
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
|
||||
# original request before the hset even starts.
|
||||
if not meta:
|
||||
logger.warning(
|
||||
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
await asyncio.sleep(0.05) # 50ms
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if not meta:
|
||||
_max_retries = 3
|
||||
_retry_delay = 0.1 # 100ms per attempt
|
||||
for attempt in range(_max_retries):
|
||||
logger.warning(
|
||||
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
|
||||
f"retrying after {int(_retry_delay * 1000)}ms",
|
||||
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
|
||||
)
|
||||
await asyncio.sleep(_retry_delay)
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if meta:
|
||||
logger.info(
|
||||
f"[TIMING] Session found after {attempt + 1} retries",
|
||||
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
|
||||
)
|
||||
break
|
||||
else:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
|
||||
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
|
||||
f"({elapsed:.1f}ms total)",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -446,10 +460,6 @@ async def subscribe_to_session(
|
||||
},
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"[TIMING] Session found after retry",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
session_status = meta.get("status", "")
|
||||
@@ -830,15 +840,26 @@ async def mark_session_completed(
|
||||
turn_id = _parse_session_meta(meta, session_id).turn_id if meta else session_id
|
||||
|
||||
# Atomic compare-and-swap: only update if status is "running"
|
||||
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||
swapped = await hash_compare_and_set(
|
||||
redis, meta_key, "status", expected="running", new=status
|
||||
)
|
||||
|
||||
# Clean up the in-memory TTL refresh tracker to prevent unbounded growth.
|
||||
_meta_ttl_refresh_at.pop(session_id, None)
|
||||
|
||||
if result == 0:
|
||||
if not swapped:
|
||||
logger.debug(f"Session {session_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
# Force-release the executor's cluster lock so the next enqueued turn can
|
||||
# acquire it immediately. The lock holder's on_run_done will also release
|
||||
# (idempotent delete); doing it here unblocks cases where the task hangs
|
||||
# past the cancel timeout or a pod crash leaves the lock orphaned.
|
||||
try:
|
||||
await redis.delete(get_session_lock_key(session_id))
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to release cluster lock for session {session_id}: {e}")
|
||||
|
||||
if error_message and not skip_error_publish:
|
||||
try:
|
||||
await publish_chunk(turn_id, StreamError(errorText=error_message))
|
||||
@@ -1061,6 +1082,9 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
ResponseType.TEXT_START.value: StreamTextStart,
|
||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||
ResponseType.REASONING_START.value: StreamReasoningStart,
|
||||
ResponseType.REASONING_DELTA.value: StreamReasoningDelta,
|
||||
ResponseType.REASONING_END.value: StreamReasoningEnd,
|
||||
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
|
||||
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
|
||||
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
|
||||
|
||||
@@ -4,8 +4,10 @@ import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.executor.utils import get_session_lock_key
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -108,3 +110,228 @@ async def test_disconnect_all_listeners_timeout_not_counted():
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stream_and_publish: closing the wrapper forwards GeneratorExit into the
|
||||
# inner stream so its finally (stream lock release, etc.) runs deterministically.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeEvent:
|
||||
"""Minimal stand-in for a StreamBaseResponse so publish_chunk is a no-op."""
|
||||
|
||||
def __init__(self, idx: int):
|
||||
self.idx = idx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_and_publish_aclose_propagates_to_inner_stream():
|
||||
"""Closing the wrapper MUST run the inner generator's finally block."""
|
||||
inner_finally_ran = asyncio.Event()
|
||||
|
||||
async def _inner():
|
||||
try:
|
||||
yield _FakeEvent(0)
|
||||
yield _FakeEvent(1)
|
||||
yield _FakeEvent(2)
|
||||
finally:
|
||||
inner_finally_ran.set()
|
||||
|
||||
inner = _inner()
|
||||
# Empty turn_id skips publish_chunk — keeps the test hermetic (no Redis).
|
||||
wrapper = stream_registry.stream_and_publish(
|
||||
session_id="sess-test", turn_id="", stream=inner
|
||||
)
|
||||
|
||||
# Consume one event, then close the wrapper early.
|
||||
first = await wrapper.__anext__()
|
||||
assert isinstance(first, _FakeEvent)
|
||||
|
||||
await wrapper.aclose()
|
||||
|
||||
# The inner generator's finally must have run deterministically
|
||||
# (not deferred to GC) so the caller's cleanup (lock release, etc.)
|
||||
# is observable right after aclose returns.
|
||||
assert inner_finally_ran.is_set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_and_publish_logs_warning_on_publish_chunk_failure():
|
||||
"""``stream_and_publish`` must not propagate a Redis publish failure —
|
||||
it warns once with full stack trace, keeps yielding, and logs
|
||||
subsequent failures at WARNING (terser, no exc_info) so repeated
|
||||
errors stay visible without flooding the trace."""
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
async def _inner():
|
||||
yield _FakeEvent(0)
|
||||
yield _FakeEvent(1)
|
||||
yield _FakeEvent(2)
|
||||
|
||||
async def _raising_publish(turn_id, event, session_id=None):
|
||||
raise RedisError("boom")
|
||||
|
||||
warning_mock = patch.object(
|
||||
stream_registry.logger, "warning", autospec=True
|
||||
).start()
|
||||
try:
|
||||
with patch.object(stream_registry, "publish_chunk", new=_raising_publish):
|
||||
wrapper = stream_registry.stream_and_publish(
|
||||
session_id="sess-test", turn_id="turn-1", stream=_inner()
|
||||
)
|
||||
received = [evt async for evt in wrapper]
|
||||
finally:
|
||||
patch.stopall()
|
||||
|
||||
# Every event still yields through — publish failures don't break the stream.
|
||||
assert len(received) == 3
|
||||
# One warning per failed publish (3 total). First call carries a
|
||||
# stack trace (``exc_info=True``); subsequent calls are terser.
|
||||
assert warning_mock.call_count == 3
|
||||
assert warning_mock.call_args_list[0].kwargs.get("exc_info") is True
|
||||
assert warning_mock.call_args_list[1].kwargs.get("exc_info") is not True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_and_publish_consumer_break_then_aclose_releases_inner():
|
||||
"""The processor pattern — break on cancel, then aclose — must release."""
|
||||
inner_finally_ran = asyncio.Event()
|
||||
|
||||
async def _inner():
|
||||
try:
|
||||
for idx in range(100):
|
||||
yield _FakeEvent(idx)
|
||||
finally:
|
||||
inner_finally_ran.set()
|
||||
|
||||
inner = _inner()
|
||||
wrapper = stream_registry.stream_and_publish(
|
||||
session_id="sess-test", turn_id="", stream=inner
|
||||
)
|
||||
|
||||
# Mimic the processor: consume a few events, simulate Stop by breaking,
|
||||
# then aclose the wrapper (as processor._execute_async now does in the
|
||||
# try/finally around the async for).
|
||||
try:
|
||||
count = 0
|
||||
async for _ in wrapper:
|
||||
count += 1
|
||||
if count >= 2:
|
||||
break
|
||||
finally:
|
||||
await wrapper.aclose()
|
||||
|
||||
assert inner_finally_ran.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mark_session_completed: the atomic meta flip to completed/failed must also
|
||||
# release the per-session cluster lock, so the next enqueued turn's run
|
||||
# handler can acquire it without waiting for the TTL (5 min default).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
"""Minimal async-Redis fake: only the calls mark_session_completed makes."""
|
||||
|
||||
def __init__(self, meta: dict[str, str]):
|
||||
self._meta = dict(meta)
|
||||
self.deleted_keys: list[str] = []
|
||||
self.delete = AsyncMock(side_effect=self._record_delete)
|
||||
|
||||
async def _record_delete(self, *keys: str):
|
||||
self.deleted_keys.extend(keys)
|
||||
for k in keys:
|
||||
self._meta.pop(k, None)
|
||||
return len(keys)
|
||||
|
||||
async def hgetall(self, _key: str):
|
||||
return dict(self._meta)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_session_completed_releases_cluster_lock_on_success():
|
||||
"""CAS swap must be followed by a DELETE on the session's lock key so a
|
||||
stuck-because-of-stale-lock session becomes immediately claimable."""
|
||||
fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"})
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
|
||||
),
|
||||
patch.object(
|
||||
stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True)
|
||||
),
|
||||
patch.object(stream_registry, "publish_chunk", new=AsyncMock()),
|
||||
patch.object(
|
||||
stream_registry.chat_db(),
|
||||
"set_turn_duration",
|
||||
new=AsyncMock(),
|
||||
create=True,
|
||||
),
|
||||
):
|
||||
result = await stream_registry.mark_session_completed("sess-1")
|
||||
|
||||
assert result is True
|
||||
assert get_session_lock_key("sess-1") in fake_redis.deleted_keys
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_session_completed_skips_lock_release_when_already_completed():
|
||||
"""CAS failure = someone else completed the session first; we must not
|
||||
delete their already-released lock, and we must NOT publish StreamFinish
|
||||
twice (the winning caller already published it)."""
|
||||
fake_redis = _FakeRedis({"status": "completed", "turn_id": "turn-1"})
|
||||
publish_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
|
||||
),
|
||||
patch.object(
|
||||
stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=False)
|
||||
),
|
||||
patch.object(stream_registry, "publish_chunk", new=publish_mock),
|
||||
):
|
||||
result = await stream_registry.mark_session_completed("sess-1")
|
||||
|
||||
assert result is False
|
||||
assert get_session_lock_key("sess-1") not in fake_redis.deleted_keys
|
||||
assert not any(
|
||||
isinstance(call.args[1], stream_registry.StreamFinish)
|
||||
for call in publish_mock.call_args_list
|
||||
), "StreamFinish must NOT be re-published on the CAS-no-op branch"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_session_completed_survives_lock_release_redis_error():
|
||||
"""A Redis hiccup during lock DELETE must not prevent the StreamFinish
|
||||
publish — the client's SSE stream would otherwise hang on the stale meta
|
||||
status while Redis recovers."""
|
||||
fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"})
|
||||
fake_redis.delete = AsyncMock(side_effect=RedisError("boom"))
|
||||
publish_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
|
||||
),
|
||||
patch.object(
|
||||
stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True)
|
||||
),
|
||||
patch.object(stream_registry, "publish_chunk", new=publish_mock),
|
||||
patch.object(
|
||||
stream_registry.chat_db(),
|
||||
"set_turn_duration",
|
||||
new=AsyncMock(),
|
||||
create=True,
|
||||
),
|
||||
):
|
||||
result = await stream_registry.mark_session_completed("sess-1")
|
||||
|
||||
assert result is True
|
||||
assert any(
|
||||
isinstance(call.args[1], stream_registry.StreamFinish)
|
||||
for call in publish_mock.call_args_list
|
||||
), "StreamFinish must still be published even if lock DELETE raises"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Shared token-usage persistence and rate-limit recording.
|
||||
"""Shared usage persistence and rate-limit recording.
|
||||
|
||||
Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
|
||||
1. Append a ``Usage`` record to the session.
|
||||
2. Log the turn's token counts.
|
||||
3. Record weighted usage in Redis for rate-limiting.
|
||||
2. Log the turn's token counts and cost.
|
||||
3. Record the real generation cost in Redis for rate-limiting.
|
||||
4. Write a PlatformCostLog entry for admin cost tracking.
|
||||
|
||||
This module extracts that common logic so both paths stay in sync.
|
||||
@@ -19,7 +19,7 @@ from backend.data.db_accessors import platform_cost_db
|
||||
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .rate_limit import record_token_usage
|
||||
from .rate_limit import record_cost_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,9 +96,14 @@ async def persist_and_record_usage(
|
||||
cost_usd: float | str | None = None,
|
||||
model: str | None = None,
|
||||
provider: str = "open_router",
|
||||
model_cost_multiplier: float = 1.0,
|
||||
) -> int:
|
||||
"""Persist token usage to session and record for rate limiting.
|
||||
"""Persist token usage to session and record generation cost for rate limiting.
|
||||
|
||||
Rate-limit counters are charged in microdollars against the provider's
|
||||
reported cost (``cost_usd``), so cache discounts and cross-model pricing
|
||||
differences are already reflected. When cost is unknown the turn is
|
||||
logged but the rate-limit counter is left alone — the caller logs an
|
||||
error at the point the absence is detected.
|
||||
|
||||
Args:
|
||||
session: The chat session to append usage to (may be None on error).
|
||||
@@ -108,11 +113,11 @@ async def persist_and_record_usage(
|
||||
cache_read_tokens: Tokens served from prompt cache (Anthropic only).
|
||||
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
|
||||
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
|
||||
cost_usd: Optional cost for logging (float from SDK, str otherwise).
|
||||
cost_usd: Real generation cost for the turn (float from SDK or parsed
|
||||
from OpenRouter usage.cost). ``None`` means the provider did not
|
||||
report a cost and rate limiting is skipped for this turn.
|
||||
model: Model identifier for cost log attribution.
|
||||
provider: Cost provider name (e.g. "anthropic", "open_router").
|
||||
model_cost_multiplier: Relative model cost factor for rate limiting
|
||||
(1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so
|
||||
more expensive models deplete the rate limit proportionally faster.
|
||||
|
||||
Returns:
|
||||
The computed total_tokens (prompt + completion; cache excluded).
|
||||
@@ -156,37 +161,51 @@ async def persist_and_record_usage(
|
||||
else:
|
||||
logger.info(
|
||||
f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens},"
|
||||
f" total={total_tokens}"
|
||||
f" total={total_tokens}, cost_usd={cost_usd}"
|
||||
)
|
||||
|
||||
if user_id:
|
||||
cost_float: float | None = None
|
||||
if cost_usd is not None:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
model_cost_multiplier=model_cost_multiplier,
|
||||
val = float(cost_usd)
|
||||
except (ValueError, TypeError):
|
||||
logger.error(
|
||||
"%s cost_usd is not numeric: %r — rate limit skipped",
|
||||
log_prefix,
|
||||
cost_usd,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
|
||||
else:
|
||||
if not math.isfinite(val):
|
||||
logger.error(
|
||||
"%s cost_usd is non-finite: %r — rate limit skipped",
|
||||
log_prefix,
|
||||
val,
|
||||
)
|
||||
elif val < 0:
|
||||
logger.warning(
|
||||
"%s cost_usd %s is negative — skipping rate-limit + cost log",
|
||||
log_prefix,
|
||||
val,
|
||||
)
|
||||
else:
|
||||
cost_float = val
|
||||
|
||||
cost_microdollars = usd_to_microdollars(cost_float)
|
||||
|
||||
if user_id and cost_microdollars is not None and cost_microdollars > 0:
|
||||
# record_cost_usage() owns its fail-open handling for Redis/network
|
||||
# errors. Don't wrap with a broad except here — unexpected accounting
|
||||
# bugs should surface instead of being silently logged as warnings.
|
||||
await record_cost_usage(
|
||||
user_id=user_id,
|
||||
cost_microdollars=cost_microdollars,
|
||||
)
|
||||
|
||||
# Log to PlatformCostLog for admin cost dashboard.
|
||||
# Include entries where cost_usd is set even if token count is 0
|
||||
# (e.g. fully-cached Anthropic responses where only cache tokens
|
||||
# accumulate a charge without incrementing total_tokens).
|
||||
if user_id and (total_tokens > 0 or cost_usd is not None):
|
||||
cost_float = None
|
||||
if cost_usd is not None:
|
||||
try:
|
||||
val = float(cost_usd)
|
||||
if math.isfinite(val) and val >= 0:
|
||||
cost_float = val
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
cost_microdollars = usd_to_microdollars(cost_float)
|
||||
if user_id and (total_tokens > 0 or cost_float is not None):
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if cost_float is not None:
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestTotalTokens:
|
||||
async def test_returns_prompt_plus_completion(self):
|
||||
"""total_tokens = prompt + completion (cache excluded from total)."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
@@ -63,7 +63,7 @@ class TestTotalTokens:
|
||||
async def test_cache_tokens_excluded_from_total(self):
|
||||
"""Cache tokens are stored separately and not added to total_tokens."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
@@ -81,7 +81,7 @@ class TestTotalTokens:
|
||||
async def test_baseline_path_no_cache(self):
|
||||
"""Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
@@ -97,7 +97,7 @@ class TestTotalTokens:
|
||||
async def test_sdk_path_with_cache(self):
|
||||
"""SDK (Anthropic) path passes cache tokens; total still = prompt + completion."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
@@ -123,7 +123,7 @@ class TestSessionPersistence:
|
||||
async def test_appends_usage_to_session(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -144,7 +144,7 @@ class TestSessionPersistence:
|
||||
async def test_appends_cache_breakdown_to_session(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -163,7 +163,7 @@ class TestSessionPersistence:
|
||||
async def test_multiple_turns_append_multiple_records(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -178,7 +178,7 @@ class TestSessionPersistence:
|
||||
async def test_none_session_does_not_raise(self):
|
||||
"""When session is None (e.g. error path), no exception should be raised."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
@@ -210,10 +210,11 @@ class TestSessionPersistence:
|
||||
|
||||
class TestRateLimitRecording:
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_record_token_usage_when_user_id_present(self):
|
||||
async def test_calls_record_cost_usage_when_cost_and_user_id_present(self):
|
||||
"""Rate-limit counter is charged with the real provider cost (microdollars)."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -223,22 +224,35 @@ class TestRateLimitRecording:
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
cost_usd=0.0123,
|
||||
)
|
||||
mock_record.assert_awaited_once_with(
|
||||
user_id="user-abc",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
model_cost_multiplier=1.0,
|
||||
cost_microdollars=12_300,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_record_when_cost_is_missing(self):
|
||||
"""Without a provider cost we have no authoritative figure to charge."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-abc",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
mock_record.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_record_when_user_id_is_none(self):
|
||||
"""Anonymous sessions should not create Redis keys."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -246,32 +260,38 @@ class TestRateLimitRecording:
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
mock_record.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_failure_does_not_raise(self):
|
||||
"""A Redis error in record_token_usage should be swallowed (fail-open)."""
|
||||
mock_record = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
async def test_record_usage_bubbles_unexpected_error(self):
|
||||
"""Unexpected errors from record_cost_usage must propagate.
|
||||
|
||||
record_cost_usage() owns its own (RedisError, ConnectionError, OSError)
|
||||
fail-open handling. Anything else is a real accounting bug and
|
||||
should not be silently swallowed at this layer.
|
||||
"""
|
||||
mock_record = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
# Should not raise
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-xyz",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
assert total == 150
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-xyz",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd=0.002,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_record_when_zero_tokens(self):
|
||||
"""Returns 0 before calling record_token_usage when tokens are zero."""
|
||||
async def test_skips_record_when_zero_tokens_and_no_cost(self):
|
||||
"""Returns 0 before calling record_cost_usage when there is nothing to record."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
@@ -295,7 +315,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -336,7 +356,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -369,7 +389,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -394,7 +414,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -423,7 +443,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -452,7 +472,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -479,7 +499,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -509,7 +529,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
@@ -545,7 +565,7 @@ class TestPlatformCostLogging:
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
"backend.copilot.token_tracking.record_cost_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
|
||||
@@ -26,6 +26,7 @@ from .fix_agent import FixAgentGraphTool
|
||||
from .get_agent_building_guide import GetAgentBuildingGuideTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .get_mcp_guide import GetMCPGuideTool
|
||||
from .get_sub_session_result import GetSubSessionResultTool
|
||||
from .graphiti_forget import MemoryForgetConfirmTool, MemoryForgetSearchTool
|
||||
from .graphiti_search import MemorySearchTool
|
||||
from .graphiti_store import MemoryStoreTool
|
||||
@@ -40,6 +41,7 @@ from .manage_folders import (
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
from .run_sub_session import RunSubSessionTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .validate_agent import ValidateAgentGraphTool
|
||||
from .web_fetch import WebFetchTool
|
||||
@@ -81,6 +83,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"continue_run_block": ContinueRunBlockTool(),
|
||||
"run_sub_session": RunSubSessionTool(),
|
||||
"get_sub_session_result": GetSubSessionResultTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"get_mcp_guide": GetMCPGuideTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
|
||||
@@ -12,7 +12,7 @@ from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.data import db as db_module
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
@@ -42,11 +42,28 @@ async def _ensure_db_connected() -> None:
|
||||
await db_module.connect()
|
||||
|
||||
|
||||
def make_session(user_id: str):
|
||||
def make_session(user_id: str, *, guide_read: bool = True):
|
||||
"""Build a fake ChatSession for tool tests.
|
||||
|
||||
``guide_read=True`` (default) pre-populates the session with a
|
||||
``get_agent_building_guide`` tool-call history entry so the agent-
|
||||
generation gate (see ``helpers.require_guide_read``) lets through any
|
||||
subsequent ``create_agent`` / ``edit_agent`` / ``validate_agent_graph``
|
||||
/ ``fix_agent_graph`` call.
|
||||
"""
|
||||
messages: list[ChatMessage] = []
|
||||
if guide_read:
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"function": {"name": "get_agent_building_guide"}}],
|
||||
)
|
||||
)
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
messages=[],
|
||||
messages=messages,
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
|
||||
@@ -1325,7 +1325,7 @@ class AgentFixer:
|
||||
"""
|
||||
if not library_agents:
|
||||
logger.debug(
|
||||
"fix_agent_executor_blocks: No library_agents provided, " "skipping"
|
||||
"fix_agent_executor_blocks: No library_agents provided, skipping"
|
||||
)
|
||||
return agent
|
||||
|
||||
@@ -1390,7 +1390,7 @@ class AgentFixer:
|
||||
if "user_id" not in input_default:
|
||||
input_default["user_id"] = ""
|
||||
self.add_fix_log(
|
||||
f"Fixed AgentExecutorBlock {node_id}: Added missing " f"user_id"
|
||||
f"Fixed AgentExecutorBlock {node_id}: Added missing user_id"
|
||||
)
|
||||
|
||||
# Ensure inputs is present
|
||||
@@ -1689,8 +1689,7 @@ class AgentFixer:
|
||||
if field not in input_default or input_default[field] is None:
|
||||
input_default[field] = default_value
|
||||
self.add_fix_log(
|
||||
f"OrchestratorBlock {node_id}: "
|
||||
f"Set {field}={default_value!r}"
|
||||
f"OrchestratorBlock {node_id}: Set {field}={default_value!r}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Tests for the ``require_guide_read`` gate on agent-generation tools.
|
||||
|
||||
The agent-building guide carries block ids, link semantics, and
|
||||
AgentExecutorBlock / MCPToolBlock conventions that the agent needs before
|
||||
producing agent JSON. Without the gate, agents often skip the guide to save
|
||||
tokens and then produce JSON that fails validation — wasting turns on
|
||||
auto-fix loops.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
|
||||
from .helpers import require_guide_read
|
||||
from .models import ErrorResponse
|
||||
|
||||
|
||||
def _session_with_messages(messages: list[ChatMessage]) -> ChatSession:
|
||||
"""Build a minimal ChatSession whose ``messages`` matches *messages*."""
|
||||
session = MagicMock(spec=ChatSession)
|
||||
session.session_id = "test-session"
|
||||
session.messages = messages
|
||||
return session
|
||||
|
||||
|
||||
def test_no_messages_gate_fires():
|
||||
session = _session_with_messages([])
|
||||
result = require_guide_read(session, "create_agent")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "get_agent_building_guide" in result.message
|
||||
assert "create_agent" in result.message
|
||||
|
||||
|
||||
def test_user_message_only_gate_fires():
|
||||
session = _session_with_messages(
|
||||
[ChatMessage(role="user", content="build an agent")]
|
||||
)
|
||||
assert isinstance(require_guide_read(session, "create_agent"), ErrorResponse)
|
||||
|
||||
|
||||
def test_assistant_without_tool_calls_gate_fires():
|
||||
session = _session_with_messages(
|
||||
[ChatMessage(role="assistant", content="sure!", tool_calls=None)]
|
||||
)
|
||||
assert isinstance(require_guide_read(session, "create_agent"), ErrorResponse)
|
||||
|
||||
|
||||
def test_unrelated_tool_call_gate_fires():
|
||||
session = _session_with_messages(
|
||||
[
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"function": {"name": "find_block"}}],
|
||||
)
|
||||
]
|
||||
)
|
||||
assert isinstance(require_guide_read(session, "create_agent"), ErrorResponse)
|
||||
|
||||
|
||||
def test_guide_called_via_openai_shape_gate_passes():
|
||||
"""OpenAI/Anthropic wrap names under 'function': {'name': ...}."""
|
||||
session = _session_with_messages(
|
||||
[
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"function": {"name": "get_agent_building_guide"}},
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
assert require_guide_read(session, "create_agent") is None
|
||||
|
||||
|
||||
def test_guide_called_via_flat_shape_gate_passes():
|
||||
"""Some callers log tool calls with a flat {'name': ...} shape."""
|
||||
session = _session_with_messages(
|
||||
[
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"name": "get_agent_building_guide"}],
|
||||
)
|
||||
]
|
||||
)
|
||||
assert require_guide_read(session, "create_agent") is None
|
||||
|
||||
|
||||
def test_guide_earlier_in_history_still_passes():
|
||||
"""A guide call earlier in the session keeps the gate open for subsequent
|
||||
create/edit/validate/fix calls — the agent doesn't need to re-read it."""
|
||||
session = _session_with_messages(
|
||||
[
|
||||
ChatMessage(role="user", content="build X"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"function": {"name": "get_agent_building_guide"}}],
|
||||
),
|
||||
ChatMessage(role="user", content="also Y"),
|
||||
ChatMessage(role="assistant", content="working on it"),
|
||||
]
|
||||
)
|
||||
assert require_guide_read(session, "edit_agent") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_name",
|
||||
["create_agent", "edit_agent", "validate_agent_graph", "fix_agent_graph"],
|
||||
)
|
||||
def test_tool_name_surfaced_in_error(tool_name: str):
|
||||
session = _session_with_messages([])
|
||||
result = require_guide_read(session, tool_name)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert tool_name in result.message
|
||||
@@ -8,6 +8,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.copilot.constants import MAX_TOOL_WAIT_SECONDS
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import execution_db, library_db
|
||||
from backend.data.execution import (
|
||||
@@ -39,7 +40,7 @@ class AgentOutputInput(BaseModel):
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
wait_if_running: int = Field(default=0, ge=0, le=300)
|
||||
wait_if_running: int = Field(default=0, ge=0, le=MAX_TOOL_WAIT_SECONDS)
|
||||
show_execution_details: bool = False
|
||||
|
||||
@field_validator(
|
||||
@@ -148,9 +149,13 @@ class AgentOutputTool(BaseTool):
|
||||
},
|
||||
"wait_if_running": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to wait if still running (0-300). Returns current state on timeout.",
|
||||
"description": (
|
||||
"Max seconds to wait if still running "
|
||||
f"(0-{MAX_TOOL_WAIT_SECONDS}). "
|
||||
"Returns current state on timeout."
|
||||
),
|
||||
"minimum": 0,
|
||||
"maximum": 300,
|
||||
"maximum": MAX_TOOL_WAIT_SECONDS,
|
||||
},
|
||||
"show_execution_details": {
|
||||
"type": "boolean",
|
||||
|
||||
@@ -47,7 +47,7 @@ class BashExecTool(BaseTool):
|
||||
return (
|
||||
"Execute a Bash command or script. Shares filesystem with SDK file tools. "
|
||||
"Useful for scripts, data processing, and package installation. "
|
||||
"Killed after timeout (default 30s, max 120s)."
|
||||
"Killed after `timeout` seconds."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -61,8 +61,8 @@ class BashExecTool(BaseTool):
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds (default 30, max 120).",
|
||||
"default": 30,
|
||||
"description": "Timeout in seconds; raise for long-running commands.",
|
||||
"default": 120,
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
@@ -80,7 +80,7 @@ class BashExecTool(BaseTool):
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
command: str = "",
|
||||
timeout: int = 30,
|
||||
timeout: int = 120,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Run a bash command on E2B (if available) or in a bubblewrap sandbox.
|
||||
@@ -129,7 +129,7 @@ class BashExecTool(BaseTool):
|
||||
message=(
|
||||
"Execution timed out"
|
||||
if timed_out
|
||||
else f"Command executed (exit {exit_code})"
|
||||
else f"Command executed with status code {exit_code}"
|
||||
),
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
@@ -183,7 +183,7 @@ class BashExecTool(BaseTool):
|
||||
stdout = stdout.replace(secret, "[REDACTED]")
|
||||
stderr = stderr.replace(secret, "[REDACTED]")
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
message=f"Command executed with status code {result.exit_code}",
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=result.exit_code,
|
||||
|
||||
@@ -35,12 +35,15 @@ class TestBashExecE2BTokenInjection:
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env, patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=None),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env,
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
):
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
@@ -69,12 +72,15 @@ class TestBashExecE2BTokenInjection:
|
||||
"GIT_COMMITTER_EMAIL": "test@example.com",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={}),
|
||||
), patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=identity),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=identity),
|
||||
),
|
||||
):
|
||||
await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
@@ -97,12 +103,15 @@ class TestBashExecE2BTokenInjection:
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={}),
|
||||
), patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=None),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
):
|
||||
await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
@@ -123,13 +132,16 @@ class TestBashExecE2BTokenInjection:
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env, patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=None),
|
||||
) as mock_get_identity:
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env,
|
||||
patch(
|
||||
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
|
||||
new=AsyncMock(return_value=None),
|
||||
) as mock_get_identity,
|
||||
):
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
|
||||
@@ -8,6 +8,7 @@ from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .base import BaseTool
|
||||
from .helpers import require_guide_read
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,8 +24,9 @@ class CreateAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent from JSON (nodes + links). Validates, auto-fixes, and saves. "
|
||||
"If you haven't already, call get_agent_building_guide first."
|
||||
"Create a new agent from JSON (nodes + links). Validates, "
|
||||
"auto-fixes, and saves. "
|
||||
"Requires get_agent_building_guide first (refuses otherwise)."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -70,6 +72,10 @@ class CreateAgentTool(BaseTool):
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
guide_gate = require_guide_read(session, "create_agent")
|
||||
if guide_gate is not None:
|
||||
return guide_gate
|
||||
|
||||
if not agent_json:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
|
||||
@@ -8,6 +8,7 @@ from backend.copilot.model import ChatSession
|
||||
from .agent_generator import get_agent_as_json
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .base import BaseTool
|
||||
from .helpers import require_guide_read
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,7 +25,7 @@ class EditAgentTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent. Validates, auto-fixes, and saves. "
|
||||
"If you haven't already, call get_agent_building_guide first."
|
||||
"Requires get_agent_building_guide first (refuses otherwise)."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -73,6 +74,10 @@ class EditAgentTool(BaseTool):
|
||||
library_agent_ids = []
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
guide_gate = require_guide_read(session, "edit_agent")
|
||||
if guide_gate is not None:
|
||||
return guide_gate
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
|
||||
@@ -42,6 +42,10 @@ COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||
# OrchestratorBlock - dynamically discovers downstream blocks via graph topology;
|
||||
# usable in agent graphs (guide hardcodes its ID) but cannot run standalone.
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
# AutoPilotBlock - has dedicated run_sub_session tool with async start +
|
||||
# poll lifecycle. Calling it via run_block would block the parent stream
|
||||
# for the sub-AutoPilot's entire runtime (15-45+ min typical).
|
||||
"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator.validation import AgentFixer, AgentValidator, get_blocks_as_dicts
|
||||
from .base import BaseTool
|
||||
from .helpers import require_guide_read
|
||||
from .models import ErrorResponse, FixResultResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,7 +26,8 @@ class FixAgentGraphTool(BaseTool):
|
||||
"Auto-fix common agent JSON issues: missing/invalid UUIDs, StoreValueBlock prerequisites, "
|
||||
"double curly brace escaping, AddToList/AddToDictionary prerequisites, credentials, "
|
||||
"node spacing, AI model defaults, link static properties, and type mismatches. "
|
||||
"Returns fixed JSON and list of fixes applied."
|
||||
"Returns fixed JSON and list of fixes applied. "
|
||||
"Requires get_agent_building_guide first (refuses otherwise)."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -56,6 +58,10 @@ class FixAgentGraphTool(BaseTool):
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
guide_gate = require_guide_read(session, "fix_agent_graph")
|
||||
if guide_gate is not None:
|
||||
return guide_gate
|
||||
|
||||
if not agent_json or not isinstance(agent_json, dict):
|
||||
return ErrorResponse(
|
||||
message="Please provide a valid agent JSON object.",
|
||||
|
||||
@@ -43,8 +43,10 @@ class GetAgentBuildingGuideTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage, "
|
||||
"and the create->dry-run->fix iterative workflow). Call before generating agent JSON."
|
||||
"Agent JSON building guide (nodes, links, AgentExecutorBlock, "
|
||||
"MCPToolBlock, iterative create->dry-run->fix flow). REQUIRED "
|
||||
"before create_agent / edit_agent / validate_agent_graph / "
|
||||
"fix_agent_graph — they refuse until called once per session."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
"""Poll / wait on / cancel a sub-AutoPilot started by ``run_sub_session``.
|
||||
|
||||
Companion to :mod:`run_sub_session`. Operates on the sub's
|
||||
``ChatSession`` directly — there is no separate registry. Ownership is
|
||||
re-verified on every call by loading the ChatSession and comparing its
|
||||
``user_id`` against the authenticated caller.
|
||||
|
||||
* **Wait** — subscribe to ``stream_registry`` for the session and drain
|
||||
until ``StreamFinish`` / ``StreamError`` (terminal) or the per-call
|
||||
cap fires. On terminal, the aggregated :class:`SessionResult` comes
|
||||
back in memory — no DB round-trip for the response content.
|
||||
* **Just check** — ``wait_if_running=0`` skips the subscription. If the
|
||||
sub's last assistant message already looks terminal, returns
|
||||
``completed`` with that content.
|
||||
* **Cancel** — fan out a ``CancelCoPilotEvent`` on the shared cancel
|
||||
exchange. Whichever worker is running the sub breaks out of its
|
||||
stream and finalises the session as ``failed``.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task
|
||||
from backend.copilot.model import ChatSession, get_chat_session
|
||||
from backend.copilot.sdk.session_waiter import (
|
||||
SessionOutcome,
|
||||
SessionResult,
|
||||
wait_for_session_result,
|
||||
)
|
||||
from backend.copilot.sdk.stream_accumulator import ToolCallEntry
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
SubSessionProgressSnapshot,
|
||||
SubSessionStatusResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .run_sub_session import (
|
||||
MAX_SUB_SESSION_WAIT_SECONDS,
|
||||
_sub_session_link,
|
||||
response_from_outcome,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cap on how many recent messages we echo back in a progress snapshot.
|
||||
_PROGRESS_MESSAGE_LIMIT = 5
|
||||
_PROGRESS_CONTENT_PREVIEW_CHARS = 400
|
||||
|
||||
|
||||
class GetSubSessionResultTool(BaseTool):
|
||||
"""Wait for, inspect, or cancel a sub-AutoPilot."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_sub_session_result"
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Poll / wait / cancel a sub-AutoPilot from run_sub_session. "
|
||||
f"Waits up to wait_if_running sec (max {MAX_SUB_SESSION_WAIT_SECONDS}); "
|
||||
"cancel=true aborts; include_progress=true returns recent messages "
|
||||
"from the still-running sub. Works across turns."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sub_session_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The sub's session_id returned by run_sub_session "
|
||||
"(also accepted: sub_autopilot_session_id — same value)."
|
||||
),
|
||||
},
|
||||
"wait_if_running": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
f"Seconds to wait. 0 = just check. Clamped to "
|
||||
f"{MAX_SUB_SESSION_WAIT_SECONDS}."
|
||||
),
|
||||
"default": 60,
|
||||
},
|
||||
"cancel": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Cancel the sub; takes precedence over wait_if_running."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
"include_progress": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Populate progress.last_messages when status=running."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["sub_session_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
sub_session_id: str = "",
|
||||
wait_if_running: int = 60,
|
||||
cancel: bool = False,
|
||||
include_progress: bool = False,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
inner_session_id = sub_session_id.strip()
|
||||
if not inner_session_id:
|
||||
return ErrorResponse(
|
||||
message="sub_session_id is required",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
if user_id is None:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
# Ownership check on every call — loads the ChatSession and
|
||||
# confirms the caller owns it. Returning the same "not found"
|
||||
# shape for "doesn't exist" and "belongs to someone else" avoids
|
||||
# leaking session existence.
|
||||
sub = await get_chat_session(inner_session_id)
|
||||
if sub is None or sub.user_id != user_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"No sub-session with id {inner_session_id}. It may have "
|
||||
"never existed or belongs to another user."
|
||||
),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
started_at = time.monotonic()
|
||||
|
||||
if cancel:
|
||||
# Fan out the cancel event. Whichever worker is running the
|
||||
# sub will break out of its stream and finalise the session
|
||||
# as failed. Return "cancelled" immediately; the sub may
|
||||
# still emit a little more output before the worker notices,
|
||||
# but the agent doesn't need to wait for that.
|
||||
await enqueue_cancel_task(inner_session_id)
|
||||
return SubSessionStatusResponse(
|
||||
message="Sub-AutoPilot cancel requested.",
|
||||
session_id=session.session_id,
|
||||
status="cancelled",
|
||||
sub_session_id=inner_session_id,
|
||||
sub_autopilot_session_id=inner_session_id,
|
||||
sub_autopilot_session_link=_sub_session_link(inner_session_id),
|
||||
elapsed_seconds=0.0,
|
||||
)
|
||||
|
||||
# If a turn is currently running for this session (stream registry
|
||||
# meta shows status=running), we can NOT short-circuit on the
|
||||
# persisted last assistant message — that message belongs to a
|
||||
# PRIOR turn, and surfacing it here would hand the caller stale
|
||||
# data while the new turn is mid-flight (sentry r3105409601).
|
||||
# Only short-circuit when there's no active turn AND the last
|
||||
# persisted message already looks terminal.
|
||||
effective_wait = max(0, min(wait_if_running, MAX_SUB_SESSION_WAIT_SECONDS))
|
||||
registry_session = await stream_registry.get_session(inner_session_id)
|
||||
turn_in_flight = registry_session is not None and (
|
||||
getattr(registry_session, "status", "") == "running"
|
||||
)
|
||||
terminal_result = None if turn_in_flight else _already_terminal_result(sub)
|
||||
outcome: SessionOutcome
|
||||
result: SessionResult
|
||||
if terminal_result is not None:
|
||||
outcome, result = "completed", terminal_result
|
||||
elif effective_wait > 0:
|
||||
outcome, result = await wait_for_session_result(
|
||||
session_id=inner_session_id,
|
||||
user_id=user_id,
|
||||
timeout=effective_wait,
|
||||
)
|
||||
else:
|
||||
outcome, result = "running", SessionResult()
|
||||
|
||||
elapsed = time.monotonic() - started_at
|
||||
|
||||
if outcome == "running" and include_progress:
|
||||
# Running + caller wants progress — hand-assemble the response
|
||||
# with the progress snapshot attached. response_from_outcome
|
||||
# doesn't carry progress, so we build the response here.
|
||||
progress = await _build_progress_snapshot(inner_session_id)
|
||||
link = _sub_session_link(inner_session_id)
|
||||
return SubSessionStatusResponse(
|
||||
message=(
|
||||
f"Sub-AutoPilot still running after {elapsed:.0f}s."
|
||||
f"{f' Watch live at {link}.' if link else ''} "
|
||||
"Call again to keep waiting, or cancel=true to abort."
|
||||
),
|
||||
session_id=session.session_id,
|
||||
status="running",
|
||||
sub_session_id=inner_session_id,
|
||||
sub_autopilot_session_id=inner_session_id,
|
||||
sub_autopilot_session_link=link,
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
return response_from_outcome(
|
||||
outcome=outcome,
|
||||
result=result,
|
||||
inner_session_id=inner_session_id,
|
||||
parent_session_id=session.session_id,
|
||||
elapsed=elapsed,
|
||||
)
|
||||
|
||||
|
||||
def _already_terminal_result(sub: ChatSession) -> SessionResult | None:
|
||||
"""Rebuild the aggregated result from the sub's persisted last turn,
|
||||
when the last message is a terminal assistant message.
|
||||
|
||||
Lets ``get_sub_session_result`` short-circuit the subscribe+wait
|
||||
when the agent polls well after the sub actually finished (a common
|
||||
case when the user pauses and later asks "what's the result?").
|
||||
Returns ``None`` if the last message isn't terminal.
|
||||
"""
|
||||
if not sub.messages:
|
||||
return None
|
||||
last = sub.messages[-1]
|
||||
if last.role != "assistant":
|
||||
return None
|
||||
if not last.content and not last.tool_calls:
|
||||
return None
|
||||
result = SessionResult()
|
||||
result.response_text = last.content or ""
|
||||
# Persisted tool calls are OpenAI-shape dicts; translate to
|
||||
# ToolCallEntry so the downstream ``response_from_outcome`` can
|
||||
# ``.model_dump()`` them uniformly with the live-drain path.
|
||||
for tc in last.tool_calls or []:
|
||||
fn = tc.get("function") or {}
|
||||
result.tool_calls.append(
|
||||
ToolCallEntry(
|
||||
tool_call_id=tc.get("id", ""),
|
||||
tool_name=fn.get("name") or tc.get("name") or "",
|
||||
input=fn.get("arguments") or tc.get("arguments") or tc.get("input"),
|
||||
output=tc.get("output"),
|
||||
success=tc.get("success"),
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _build_progress_snapshot(
|
||||
inner_session_id: str | None,
|
||||
) -> SubSessionProgressSnapshot | None:
|
||||
"""Read the sub's ChatSession and return a preview of recent messages.
|
||||
|
||||
Returns ``None`` silently on lookup failure — progress is best-effort;
|
||||
missing progress shouldn't abort the normal ``still running`` response.
|
||||
"""
|
||||
if not inner_session_id:
|
||||
return None
|
||||
try:
|
||||
sub = await get_chat_session(inner_session_id)
|
||||
if sub is None:
|
||||
return None
|
||||
messages = list(sub.messages)
|
||||
except Exception as exc: # best-effort peek
|
||||
logger.debug(
|
||||
"Progress snapshot unavailable for sub %s: %s",
|
||||
inner_session_id,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
tail = messages[-_PROGRESS_MESSAGE_LIMIT:]
|
||||
previews: list[dict[str, Any]] = []
|
||||
for msg in tail:
|
||||
content = getattr(msg, "content", "") or ""
|
||||
if not isinstance(content, str):
|
||||
try:
|
||||
content = json.dumps(content, default=str)
|
||||
except (TypeError, ValueError):
|
||||
content = str(content)
|
||||
if len(content) > _PROGRESS_CONTENT_PREVIEW_CHARS:
|
||||
content = content[:_PROGRESS_CONTENT_PREVIEW_CHARS] + "…"
|
||||
previews.append(
|
||||
{
|
||||
"role": getattr(msg, "role", "unknown"),
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
return SubSessionProgressSnapshot(
|
||||
message_count=len(messages),
|
||||
last_messages=previews,
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
@@ -14,6 +15,7 @@ from backend.copilot.constants import (
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR,
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
MAX_TOOL_WAIT_SECONDS,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
|
||||
@@ -85,6 +87,71 @@ def get_inputs_from_schema(
|
||||
return results
|
||||
|
||||
|
||||
async def _charge_block_credits(
|
||||
_credit_db: Any,
|
||||
*,
|
||||
user_id: str,
|
||||
block_name: str,
|
||||
block_id: str,
|
||||
node_exec_id: str,
|
||||
cost: int,
|
||||
cost_filter: dict[str, Any],
|
||||
synthetic_graph_id: str,
|
||||
synthetic_node_id: str,
|
||||
) -> None:
|
||||
"""Charge credits for a block execution and log any billing leak.
|
||||
|
||||
Centralised so the normal-path charge and the cancellation-recovery charge
|
||||
(see ``execute_block``'s finally) use the same metadata and the same
|
||||
leak-logging contract.
|
||||
"""
|
||||
try:
|
||||
await _credit_db.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=block_id,
|
||||
block=block_name,
|
||||
input=cost_filter,
|
||||
reason="copilot_block_execution",
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# Block already executed (with possible side effects). Never
|
||||
# return ErrorResponse here — the user received output and
|
||||
# deserves it. Log the billing failure for reconciliation.
|
||||
leak_type = (
|
||||
"INSUFFICIENT_BALANCE"
|
||||
if isinstance(e, InsufficientBalanceError)
|
||||
else "UNEXPECTED_ERROR"
|
||||
)
|
||||
logger.error(
|
||||
"BILLING_LEAK[%s]: block executed but credit charge failed — "
|
||||
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
|
||||
leak_type,
|
||||
user_id,
|
||||
block_id,
|
||||
node_exec_id,
|
||||
cost,
|
||||
e,
|
||||
extra={
|
||||
"json_fields": {
|
||||
"billing_leak": True,
|
||||
"leak_type": leak_type,
|
||||
"user_id": user_id,
|
||||
"cost": str(cost),
|
||||
}
|
||||
},
|
||||
)
|
||||
# Intentionally swallow. Block already executed with possible side
|
||||
# effects; the caller must still return BlockOutputResponse. The
|
||||
# BILLING_LEAK log above is the signal for reconciliation.
|
||||
|
||||
|
||||
async def execute_block(
|
||||
*,
|
||||
block: AnyBlockSchema,
|
||||
@@ -210,67 +277,97 @@ async def execute_block(
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
# Execute the block under the shared MCP wait cap. A block is
|
||||
# expected to finish in MAX_TOOL_WAIT_SECONDS; if it doesn't, the
|
||||
# MCP handler would block the stream close to the idle timeout.
|
||||
# wait_for cancels the generator on timeout, but the finally below
|
||||
# still settles billing via asyncio.shield — external side effects
|
||||
# may already have landed and the user should be charged for them.
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
charge_handled = False
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_collect_block_outputs(block, input_data, exec_kwargs, outputs),
|
||||
timeout=MAX_TOOL_WAIT_SECONDS,
|
||||
)
|
||||
|
||||
# Charge credits for block execution
|
||||
if has_cost:
|
||||
try:
|
||||
await _credit_db.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
# Normal (non-cancelled) path. Mark charge_handled BEFORE the
|
||||
# await so an outer cancellation landing mid-charge can't race
|
||||
# the finally block into a double-charge. asyncio.shield keeps
|
||||
# the spend running to completion even if the outer awaitable
|
||||
# is cancelled.
|
||||
if has_cost:
|
||||
charge_handled = True
|
||||
await asyncio.shield(
|
||||
_charge_block_credits(
|
||||
_credit_db,
|
||||
user_id=user_id,
|
||||
block_name=block.name,
|
||||
block_id=block_id,
|
||||
block=block.name,
|
||||
input=cost_filter,
|
||||
reason="copilot_block_execution",
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# Block already executed (with possible side effects). Never
|
||||
# return ErrorResponse here — the user received output and
|
||||
# deserves it. Log the billing failure for reconciliation.
|
||||
leak_type = (
|
||||
"INSUFFICIENT_BALANCE"
|
||||
if isinstance(e, InsufficientBalanceError)
|
||||
else "UNEXPECTED_ERROR"
|
||||
)
|
||||
logger.error(
|
||||
"BILLING_LEAK[%s]: block executed but credit charge failed — "
|
||||
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
|
||||
leak_type,
|
||||
user_id,
|
||||
block_id,
|
||||
node_exec_id,
|
||||
cost,
|
||||
e,
|
||||
extra={
|
||||
"json_fields": {
|
||||
"billing_leak": True,
|
||||
"leak_type": leak_type,
|
||||
"user_id": user_id,
|
||||
"cost": str(cost),
|
||||
}
|
||||
},
|
||||
node_exec_id=node_exec_id,
|
||||
cost=cost,
|
||||
cost_filter=cost_filter,
|
||||
synthetic_graph_id=synthetic_graph_id,
|
||||
synthetic_node_id=synthetic_node_id,
|
||||
)
|
||||
)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Structured record of tool-call timeouts (SECRT-2247 part 3).
|
||||
# Grep prod logs for `copilot_tool_timeout` to find tools that
|
||||
# keep hitting the cap — candidates for prompt tuning or
|
||||
# escalation to the async start+poll pattern.
|
||||
logger.warning(
|
||||
"copilot_tool_timeout tool=run_block block=%s block_id=%s "
|
||||
"input_keys=%s user=%s session=%s cap_s=%d",
|
||||
block.name,
|
||||
block_id,
|
||||
sorted(input_data.keys()),
|
||||
user_id,
|
||||
session_id,
|
||||
MAX_TOOL_WAIT_SECONDS,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' exceeded the "
|
||||
f"{MAX_TOOL_WAIT_SECONDS}s single-tool wait cap and was "
|
||||
"cancelled. Long-running work should go through run_agent "
|
||||
"(graph executions) or run_sub_session (sub-AutoPilot "
|
||||
"tasks) — those use async start+poll so nothing blocks "
|
||||
"the chat stream."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
finally:
|
||||
# Sentry r3105079148: asyncio.wait_for raises CancelledError into
|
||||
# the generator. Normal `except Exception` doesn't catch it, so
|
||||
# without this finally a cancelled block would skip credit
|
||||
# charging entirely while external side effects still landed.
|
||||
# Only run when the normal-path charge was NOT reached (the flag
|
||||
# is set before the await, so any cancellation during charge still
|
||||
# sets it and avoids double-billing — r3105216985).
|
||||
if has_cost and outputs and not charge_handled:
|
||||
await asyncio.shield(
|
||||
_charge_block_credits(
|
||||
_credit_db,
|
||||
user_id=user_id,
|
||||
block_name=block.name,
|
||||
block_id=block_id,
|
||||
node_exec_id=node_exec_id,
|
||||
cost=cost,
|
||||
cost_filter=cost_filter,
|
||||
synthetic_graph_id=synthetic_graph_id,
|
||||
synthetic_node_id=synthetic_node_id,
|
||||
)
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning("Block execution failed: %s", e)
|
||||
@@ -288,6 +385,23 @@ async def execute_block(
|
||||
)
|
||||
|
||||
|
||||
async def _collect_block_outputs(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
exec_kwargs: dict[str, Any],
|
||||
outputs: dict[str, list[Any]],
|
||||
) -> None:
|
||||
"""Drive ``block.execute`` and append each emitted pair to *outputs*.
|
||||
|
||||
Extracted so ``asyncio.wait_for`` can wrap exactly the generator-
|
||||
consumption step; callers read ``outputs`` afterwards (including from
|
||||
the cancellation path) to decide whether the block produced enough
|
||||
side-effects to warrant billing.
|
||||
"""
|
||||
async for output_name, output_data in block.execute(input_data, **exec_kwargs):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
|
||||
async def resolve_block_credentials(
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
@@ -655,3 +769,51 @@ def _resolve_discriminated_credentials(
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent-generation gate
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Tools that produce or modify agent JSON (create_agent, edit_agent,
|
||||
# validate_agent_graph, fix_agent_graph) require the parent agent to have
|
||||
# read the agent-building guide first — otherwise it tends to generate
|
||||
# JSON that doesn't match the current block schemas, link semantics, or
|
||||
# AgentExecutorBlock conventions, then waste turns fixing validation
|
||||
# errors. ``require_guide_read`` returns an ``ErrorResponse`` the caller
|
||||
# should short-circuit with, or ``None`` when the guide has been read.
|
||||
|
||||
|
||||
_AGENT_GUIDE_TOOL_NAME = "get_agent_building_guide"
|
||||
|
||||
|
||||
def _guide_read_in_session(session: ChatSession) -> bool:
|
||||
"""True if this session's assistant messages include a guide tool call."""
|
||||
for msg in reversed(session.messages):
|
||||
if msg.role != "assistant" or not msg.tool_calls:
|
||||
continue
|
||||
for tc in msg.tool_calls:
|
||||
name = tc.get("function", {}).get("name") or tc.get("name")
|
||||
if name == _AGENT_GUIDE_TOOL_NAME:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def require_guide_read(session: ChatSession, tool_name: str):
|
||||
"""Return an ErrorResponse if the guide hasn't been loaded this session.
|
||||
|
||||
Import inline to keep ``helpers.py`` free of tool-response imports.
|
||||
"""
|
||||
from .models import ErrorResponse # noqa: PLC0415 — avoid circular import
|
||||
|
||||
if _guide_read_in_session(session):
|
||||
return None
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Call get_agent_building_guide first, then retry {tool_name}. "
|
||||
"The guide documents required block ids, input/output schemas, "
|
||||
"link semantics, and AgentExecutorBlock / MCPToolBlock usage — "
|
||||
"generating agent JSON without it produces schema mismatches."
|
||||
),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
@@ -259,6 +259,90 @@ class ErrorResponse(ToolResponseBase):
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SubSessionProgressSnapshot(BaseModel):
|
||||
"""Mid-flight snapshot of a running sub-AutoPilot.
|
||||
|
||||
Returned under ``progress`` on :class:`SubSessionStatusResponse` when the
|
||||
caller passes ``include_progress=true`` while the sub is still running.
|
||||
"""
|
||||
|
||||
message_count: int = Field(
|
||||
description="Total messages in the sub's ChatSession so far.",
|
||||
)
|
||||
last_messages: list[dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Up to the last 5 messages (role + truncated content) from the "
|
||||
"sub's ChatSession — lets the agent report intermediate progress."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SubSessionStatusResponse(ToolResponseBase):
|
||||
"""Status / result of a sub-AutoPilot run started by ``run_sub_session``.
|
||||
|
||||
Returned by both ``run_sub_session`` (synchronously when the sub finishes
|
||||
within ``wait_for_result``, else with ``status='running'``) and
|
||||
``get_sub_session_result`` when the agent polls.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.MCP_TOOL_OUTPUT
|
||||
status: Literal["running", "completed", "cancelled", "error", "queued"] = Field(
|
||||
description=(
|
||||
"Current state of the sub-AutoPilot run. ``queued`` means the "
|
||||
"target session already had a turn in flight, so the message was "
|
||||
"pushed onto its pending buffer and will be picked up by the "
|
||||
"existing turn on its next drain."
|
||||
),
|
||||
)
|
||||
sub_session_id: str = Field(
|
||||
description=(
|
||||
"Opaque id for this run. Pass to ``get_sub_session_result`` or "
|
||||
"``run_sub_session(cancel=true, ...)`` to interact with it."
|
||||
),
|
||||
)
|
||||
response: str | None = Field(
|
||||
default=None,
|
||||
description="Assistant response text when status=completed.",
|
||||
)
|
||||
sub_autopilot_session_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The session_id of the sub-AutoPilot conversation. Use with "
|
||||
"``run_sub_session(..., sub_autopilot_session_id=<this>)`` "
|
||||
"to continue it."
|
||||
),
|
||||
)
|
||||
sub_autopilot_session_link: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Relative URL the user can click to open the sub-AutoPilot "
|
||||
"conversation in the CoPilot UI. Always set when "
|
||||
"``sub_autopilot_session_id`` is set."
|
||||
),
|
||||
)
|
||||
tool_calls: list[dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description="Tool calls made during the sub-AutoPilot run.",
|
||||
)
|
||||
error: str | None = Field(
|
||||
default=None,
|
||||
description="Error message when status=error.",
|
||||
)
|
||||
elapsed_seconds: float | None = Field(
|
||||
default=None,
|
||||
description="How long the sub-AutoPilot has been running (or took).",
|
||||
)
|
||||
progress: SubSessionProgressSnapshot | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Mid-flight progress snapshot. Populated only when "
|
||||
"get_sub_session_result is called with include_progress=true "
|
||||
"and the sub is still running."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class InputValidationErrorResponse(ToolResponseBase):
|
||||
"""Response when run_agent receives unknown input fields."""
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.constants import MAX_TOOL_WAIT_SECONDS
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
|
||||
from backend.data.db_accessors import graph_db, library_db, user_db
|
||||
@@ -71,7 +72,7 @@ class RunAgentInput(BaseModel):
|
||||
schedule_name: str = ""
|
||||
cron: str = ""
|
||||
timezone: str = "UTC"
|
||||
wait_for_result: int = Field(default=0, ge=0, le=300)
|
||||
wait_for_result: int = Field(default=0, ge=0, le=MAX_TOOL_WAIT_SECONDS)
|
||||
dry_run: bool = Field(default=False)
|
||||
|
||||
@field_validator(
|
||||
@@ -150,9 +151,12 @@ class RunAgentTool(BaseTool):
|
||||
},
|
||||
"wait_for_result": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to wait for completion (0-300).",
|
||||
"description": (
|
||||
"Max seconds to wait for completion "
|
||||
f"(0-{MAX_TOOL_WAIT_SECONDS})."
|
||||
),
|
||||
"minimum": 0,
|
||||
"maximum": 300,
|
||||
"maximum": MAX_TOOL_WAIT_SECONDS,
|
||||
},
|
||||
"dry_run": {
|
||||
"type": "boolean",
|
||||
|
||||
@@ -140,7 +140,9 @@ class TestRunBlockFiltering:
|
||||
async def test_block_denied_by_permissions_returns_error(self):
|
||||
"""A block denied by CopilotPermissions returns an ErrorResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
# NB: must not match any id in COPILOT_EXCLUDED_BLOCK_IDS — we want
|
||||
# the permissions guard to fire, not the exclusion guard.
|
||||
block_id = "11111111-2222-3333-4444-555555555555"
|
||||
standard_block = make_mock_block(block_id, "HTTP Request", BlockType.STANDARD)
|
||||
|
||||
perms = CopilotPermissions(blocks=[block_id], blocks_exclude=True)
|
||||
@@ -645,3 +647,230 @@ class TestRunBlockSensitiveAction:
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
|
||||
|
||||
class TestExecuteBlockTimeout:
|
||||
"""``execute_block`` caps the block's generator consumption at
|
||||
MAX_TOOL_WAIT_SECONDS and must:
|
||||
1. Return an actionable ErrorResponse pointing at run_agent / run_sub_session.
|
||||
2. Log a ``copilot_tool_timeout`` warning (SECRT-2247 part 3).
|
||||
3. Still charge credits when outputs were produced before the timeout
|
||||
(sentry r3105079148 — cancellation must not leak billing)."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_timeout_returns_error_and_logs(self, caplog):
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
|
||||
mock_block = MagicMock()
|
||||
mock_block.name = "SlowBlock"
|
||||
mock_block.id = "slow-block-id"
|
||||
mock_block.input_schema = MagicMock()
|
||||
mock_block.input_schema.jsonschema.return_value = {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
|
||||
async def _hang(_input, **_kwargs):
|
||||
await asyncio.sleep(10)
|
||||
yield "never", "never"
|
||||
|
||||
mock_block.execute = _hang
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(0, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.MAX_TOOL_WAIT_SECONDS",
|
||||
0.05,
|
||||
),
|
||||
caplog.at_level(logging.WARNING, logger="backend.copilot.tools.helpers"),
|
||||
):
|
||||
response = await execute_block(
|
||||
block=mock_block,
|
||||
block_id="slow-block-id",
|
||||
input_data={"x": 1},
|
||||
user_id="u-1",
|
||||
session_id="s-1",
|
||||
node_exec_id="n-1",
|
||||
matched_credentials={},
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "single-tool wait cap" in response.message
|
||||
assert "run_agent" in response.message
|
||||
assert any(
|
||||
"copilot_tool_timeout" in record.getMessage() for record in caplog.records
|
||||
), "timeout must emit a grep-friendly log line for SECRT-2247 part 3"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_cancellation_after_output_still_charges_credits(self):
|
||||
"""Regression for sentry r3105079148 — wait_for's CancelledError
|
||||
bypassed credit charging; fix uses a shielded finally. One output
|
||||
emitted, then timeout: spend_credits must still be called once."""
|
||||
import asyncio
|
||||
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
|
||||
mock_block = MagicMock()
|
||||
mock_block.name = "CostlyBlock"
|
||||
mock_block.id = "costly-block-id"
|
||||
mock_block.input_schema = MagicMock()
|
||||
mock_block.input_schema.jsonschema.return_value = {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
|
||||
# Generator: emit ONE output (simulating a side-effectful API call),
|
||||
# then hang — execute_block's internal wait_for cancels us.
|
||||
async def _one_output_then_hang(_input, **_kw):
|
||||
yield "result", "side effect happened"
|
||||
await asyncio.sleep(10)
|
||||
yield "extra", "should never arrive"
|
||||
|
||||
mock_block.execute = _one_output_then_hang
|
||||
|
||||
charged: dict[str, object] = {}
|
||||
|
||||
class _FakeCreditDB:
|
||||
async def get_credits(self, _user_id: str) -> int:
|
||||
return 10_000
|
||||
|
||||
async def spend_credits(self, **kwargs):
|
||||
charged["last"] = kwargs
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.credit_db",
|
||||
return_value=_FakeCreditDB(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(5, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.MAX_TOOL_WAIT_SECONDS",
|
||||
0.2,
|
||||
),
|
||||
):
|
||||
response = await execute_block(
|
||||
block=mock_block,
|
||||
block_id="costly-block-id",
|
||||
input_data={},
|
||||
user_id="u-42",
|
||||
session_id="s-42",
|
||||
node_exec_id="n-42",
|
||||
matched_credentials={},
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
# Cap fired → response is the timeout ErrorResponse
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "single-tool wait cap" in response.message
|
||||
|
||||
# Critical: billing ran via the shielded finally despite the cancellation
|
||||
assert charged.get("last") is not None, (
|
||||
"Credits were NOT charged after cancellation — billing leak "
|
||||
"(sentry r3105079148)"
|
||||
)
|
||||
assert charged["last"]["user_id"] == "u-42"
|
||||
assert charged["last"]["cost"] == 5
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_double_charge_on_cancellation_during_charge(self):
|
||||
"""Regression for sentry r3105216985 — if the caller cancels during
|
||||
the normal-path credit charge, the finally must NOT charge a second
|
||||
time. The fix marks charge_handled BEFORE awaiting spend_credits."""
|
||||
import asyncio
|
||||
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
|
||||
mock_block = MagicMock()
|
||||
mock_block.name = "OnceOnlyBlock"
|
||||
mock_block.id = "once-only-id"
|
||||
mock_block.input_schema = MagicMock()
|
||||
mock_block.input_schema.jsonschema.return_value = {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
|
||||
async def _one_then_done(_input, **_kw):
|
||||
yield "result", "done"
|
||||
|
||||
mock_block.execute = _one_then_done
|
||||
|
||||
spend_calls: list[dict] = []
|
||||
|
||||
class _CountingCreditDB:
|
||||
async def get_credits(self, _user_id: str) -> int:
|
||||
return 10_000
|
||||
|
||||
async def spend_credits(self, **kwargs):
|
||||
# Cooperative suspension so an outer cancellation can
|
||||
# theoretically interleave — shield should still make this
|
||||
# complete exactly once.
|
||||
await asyncio.sleep(0)
|
||||
spend_calls.append(kwargs)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.credit_db",
|
||||
return_value=_CountingCreditDB(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(7, {}),
|
||||
),
|
||||
):
|
||||
response = await execute_block(
|
||||
block=mock_block,
|
||||
block_id="once-only-id",
|
||||
input_data={},
|
||||
user_id="u-single",
|
||||
session_id="s-single",
|
||||
node_exec_id="n-single",
|
||||
matched_credentials={},
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
assert len(spend_calls) == 1, (
|
||||
f"spend_credits must be called exactly once, got {len(spend_calls)} "
|
||||
"(double-charge — sentry r3105216985)"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
"""Start a sub-AutoPilot conversation via the copilot_executor queue.
|
||||
|
||||
Mirror-image of ``run_agent`` + ``view_agent_output`` for copilot turns:
|
||||
|
||||
1. The tool creates (or validates ownership of) an inner ``ChatSession``
|
||||
and calls :func:`run_copilot_turn_via_queue` — the shared primitive
|
||||
that creates the stream-registry session meta, enqueues a
|
||||
``CoPilotExecutionEntry``, and waits on the Redis stream until the
|
||||
terminal event arrives or the cap fires.
|
||||
2. Any available ``copilot_executor`` worker claims the job, runs
|
||||
the SDK stream to completion, and publishes the final
|
||||
``StreamFinish`` event on the session's Redis stream.
|
||||
3. If the terminal event arrives in the wait window, the aggregated
|
||||
:class:`SessionResult` (response text, tool calls, usage) comes back
|
||||
in memory — no DB round-trip. Otherwise the tool returns
|
||||
``status="running"`` + the sub's ``session_id`` and the agent polls
|
||||
via :mod:`get_sub_session_result`.
|
||||
|
||||
Compared to the prior in-process ``asyncio.Task`` implementation this
|
||||
gives us deploy/crash resilience, natural load balancing across
|
||||
workers, and a uniform conversation model — a sub is just another
|
||||
copilot turn routed through the same queue and event bus as every
|
||||
other turn.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.constants import MAX_TOOL_WAIT_SECONDS
|
||||
from backend.copilot.context import get_current_permissions
|
||||
from backend.copilot.model import ChatSession, create_chat_session, get_chat_session
|
||||
from backend.copilot.sdk.session_waiter import (
|
||||
SessionOutcome,
|
||||
SessionResult,
|
||||
run_copilot_turn_via_queue,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, SubSessionStatusResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Max wait for a single run_sub_session / get_sub_session_result call.
|
||||
# Shared with every other long-running tool so the stream idle timeout's
|
||||
# 2x headroom holds uniformly.
|
||||
MAX_SUB_SESSION_WAIT_SECONDS = MAX_TOOL_WAIT_SECONDS
|
||||
|
||||
|
||||
class RunSubSessionTool(BaseTool):
|
||||
"""Delegate a task to a fresh sub-AutoPilot via the copilot_executor queue."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_sub_session"
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delegate a task to a fresh sub-AutoPilot. Runs on the copilot "
|
||||
"executor queue — survives tab-close AND worker restarts. Waits "
|
||||
f"up to wait_for_result sec (max {MAX_SUB_SESSION_WAIT_SECONDS}). "
|
||||
"If not done, returns status=running + sub_session_id — poll via "
|
||||
"get_sub_session_result."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The task for the sub-AutoPilot to execute.",
|
||||
},
|
||||
"system_context": {
|
||||
"type": "string",
|
||||
"description": "Optional context prepended to the prompt.",
|
||||
"default": "",
|
||||
},
|
||||
"sub_autopilot_session_id": {
|
||||
"type": "string",
|
||||
"description": ("Continue/queue-into a prior sub; empty = new."),
|
||||
"default": "",
|
||||
},
|
||||
"wait_for_result": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Seconds to wait inline. 0 = return immediately. "
|
||||
f"Clamped to {MAX_SUB_SESSION_WAIT_SECONDS}."
|
||||
),
|
||||
"default": 60,
|
||||
},
|
||||
},
|
||||
"required": ["prompt"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
prompt: str = "",
|
||||
system_context: str = "",
|
||||
sub_autopilot_session_id: str = "",
|
||||
wait_for_result: int = 60,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
if not prompt.strip():
|
||||
return ErrorResponse(
|
||||
message="prompt is required",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
if user_id is None:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
# Resolve the sub's ChatSession id — either resume an owned one or
|
||||
# create a fresh session that inherits the parent's dry_run so a
|
||||
# sub spawned inside a dry-run conversation doesn't silently
|
||||
# escalate to a live run.
|
||||
sub_session_param = sub_autopilot_session_id.strip()
|
||||
if sub_session_param:
|
||||
owned = await get_chat_session(sub_session_param)
|
||||
if owned is None or owned.user_id != user_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"sub_autopilot_session_id {sub_session_param} is not "
|
||||
"a session you own. Leave empty to start a fresh sub, "
|
||||
"or pass a session_id returned by a previous "
|
||||
"run_sub_session call of yours."
|
||||
),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
inner_session_id = sub_session_param
|
||||
else:
|
||||
new_session = await create_chat_session(user_id, dry_run=session.dry_run)
|
||||
inner_session_id = new_session.session_id
|
||||
|
||||
effective_prompt = prompt
|
||||
if system_context.strip():
|
||||
effective_prompt = f"[System Context: {system_context.strip()}]\n\n{prompt}"
|
||||
|
||||
cap = max(0, min(wait_for_result, MAX_SUB_SESSION_WAIT_SECONDS))
|
||||
started_at = time.monotonic()
|
||||
outcome, result = await run_copilot_turn_via_queue(
|
||||
session_id=inner_session_id,
|
||||
user_id=user_id,
|
||||
message=effective_prompt,
|
||||
timeout=cap,
|
||||
permissions=get_current_permissions(),
|
||||
tool_call_id=(f"sub:{session.session_id}" if session.session_id else "sub"),
|
||||
tool_name="run_sub_session",
|
||||
)
|
||||
elapsed = time.monotonic() - started_at
|
||||
return response_from_outcome(
|
||||
outcome=outcome,
|
||||
result=result,
|
||||
inner_session_id=inner_session_id,
|
||||
parent_session_id=session.session_id,
|
||||
elapsed=elapsed,
|
||||
)
|
||||
|
||||
|
||||
def _sub_session_link(inner_session_id: str | None) -> str | None:
|
||||
"""Build the CoPilot UI URL for a sub-AutoPilot session.
|
||||
|
||||
Kept in one place so the format stays consistent across the
|
||||
running/completed/error paths, and so the frontend only has one
|
||||
contract to honour.
|
||||
"""
|
||||
if not inner_session_id:
|
||||
return None
|
||||
return f"/copilot?sessionId={inner_session_id}"
|
||||
|
||||
|
||||
def response_from_outcome(
|
||||
*,
|
||||
outcome: SessionOutcome,
|
||||
result: SessionResult,
|
||||
inner_session_id: str,
|
||||
parent_session_id: str | None,
|
||||
elapsed: float,
|
||||
) -> SubSessionStatusResponse:
|
||||
"""Translate a ``(SessionOutcome, SessionResult)`` tuple into the
|
||||
``SubSessionStatusResponse`` contract the LLM sees.
|
||||
|
||||
``completed`` surfaces the aggregated response text + tool calls.
|
||||
``failed`` returns the error marker with the same handles.
|
||||
``running`` returns just the polling handles so the agent can resume.
|
||||
``queued`` means the target session already had a turn in flight; the
|
||||
message was appended to its pending buffer and will be processed by
|
||||
the existing turn on its next drain.
|
||||
"""
|
||||
link = _sub_session_link(inner_session_id)
|
||||
if outcome == "queued":
|
||||
return SubSessionStatusResponse(
|
||||
message=(
|
||||
f"Target session already had a turn in flight; the message "
|
||||
f"was queued ({result.pending_buffer_length} now pending) and "
|
||||
"will be processed by the existing turn on its next drain. "
|
||||
f"Call get_sub_session_result to poll progress"
|
||||
f"{f' or watch live at {link}' if link else ''}."
|
||||
),
|
||||
session_id=parent_session_id,
|
||||
status="queued",
|
||||
sub_session_id=inner_session_id,
|
||||
sub_autopilot_session_id=inner_session_id,
|
||||
sub_autopilot_session_link=link,
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
)
|
||||
|
||||
if outcome == "running":
|
||||
return SubSessionStatusResponse(
|
||||
message=(
|
||||
f"Sub-AutoPilot is still running after {elapsed:.0f}s."
|
||||
f"{f' Watch live at {link}.' if link else ''} "
|
||||
"Call get_sub_session_result (optionally with "
|
||||
"include_progress=true) to wait, poll, or inspect progress."
|
||||
),
|
||||
session_id=parent_session_id,
|
||||
status="running",
|
||||
sub_session_id=inner_session_id,
|
||||
sub_autopilot_session_id=inner_session_id,
|
||||
sub_autopilot_session_link=link,
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
)
|
||||
|
||||
if outcome == "failed":
|
||||
return SubSessionStatusResponse(
|
||||
message="Sub-AutoPilot failed. See the sub's transcript for details.",
|
||||
session_id=parent_session_id,
|
||||
status="error",
|
||||
sub_session_id=inner_session_id,
|
||||
sub_autopilot_session_id=inner_session_id,
|
||||
sub_autopilot_session_link=link,
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
)
|
||||
|
||||
# completed
|
||||
return SubSessionStatusResponse(
|
||||
message=f"Sub-AutoPilot completed.{f' View at {link}.' if link else ''}",
|
||||
session_id=parent_session_id,
|
||||
status="completed",
|
||||
sub_session_id=inner_session_id,
|
||||
sub_autopilot_session_id=inner_session_id,
|
||||
sub_autopilot_session_link=link,
|
||||
response=result.response_text,
|
||||
tool_calls=[tc.model_dump() for tc in result.tool_calls],
|
||||
elapsed_seconds=round(elapsed, 2),
|
||||
)
|
||||
@@ -0,0 +1,523 @@
|
||||
"""Tests for run_sub_session + get_sub_session_result (queue-backed flow).
|
||||
|
||||
Sub-AutoPilots are enqueued on the copilot_execution RabbitMQ queue and
|
||||
executed by any copilot_executor worker. The tools wait for completion
|
||||
by subscribing to ``stream_registry`` for the sub's ChatSession. These
|
||||
tests patch the three integration seams — ``enqueue_copilot_turn``,
|
||||
``wait_for_session_result``, and ``stream_registry.create_session``
|
||||
— to exercise the tool logic without needing RabbitMQ or Redis.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from .get_sub_session_result import GetSubSessionResultTool
|
||||
from .models import ErrorResponse, SubSessionStatusResponse
|
||||
from .run_sub_session import MAX_SUB_SESSION_WAIT_SECONDS, RunSubSessionTool
|
||||
|
||||
|
||||
def _session(user_id: str = "u", session_id: str = "s1") -> MagicMock:
|
||||
sess = MagicMock()
|
||||
sess.session_id = session_id
|
||||
sess.dry_run = False
|
||||
return sess
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue(monkeypatch):
|
||||
"""Patch the enqueue helpers + the stream-registry session creator at
|
||||
the source modules (session_waiter / get_sub_session_result) so tests
|
||||
don't need RabbitMQ or Redis. Returns a dict of the mocks so
|
||||
individual tests can assert on them.
|
||||
"""
|
||||
enqueue_turn = AsyncMock()
|
||||
enqueue_cancel = AsyncMock()
|
||||
create_session = AsyncMock()
|
||||
|
||||
# run_sub_session calls enqueue_copilot_turn via session_waiter's
|
||||
# run_copilot_turn_via_queue helper — patch at the helper's source.
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
|
||||
enqueue_turn,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.enqueue_cancel_task",
|
||||
enqueue_cancel,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
|
||||
create_session,
|
||||
)
|
||||
return {
|
||||
"enqueue_turn": enqueue_turn,
|
||||
"enqueue_cancel": enqueue_cancel,
|
||||
"create_session": create_session,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_waiter(monkeypatch):
|
||||
"""Patch the queue-backed primitive and the lightweight waiter so
|
||||
tests can drive outcome + result deterministically. Returns the
|
||||
``run_copilot_turn_via_queue`` mock (used by run_sub_session) and
|
||||
the ``wait_for_session_result`` mock (used by get_sub_session_result)
|
||||
wired to return ``("running", SessionResult())`` by default."""
|
||||
from backend.copilot.sdk.session_waiter import SessionResult
|
||||
|
||||
turn_mock = AsyncMock(return_value=("running", SessionResult()))
|
||||
result_mock = AsyncMock(return_value=("running", SessionResult()))
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.run_sub_session.run_copilot_turn_via_queue",
|
||||
turn_mock,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.wait_for_session_result",
|
||||
result_mock,
|
||||
)
|
||||
# Single handle with both attrs for tests that only care about one.
|
||||
turn_mock.result_mock = result_mock
|
||||
return turn_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(monkeypatch):
|
||||
"""Patch the model-layer helpers the tools call for session CRUD +
|
||||
ownership checks. The create side returns a fake ChatSession with a
|
||||
fresh uuid each call."""
|
||||
created: list[MagicMock] = []
|
||||
|
||||
async def fake_create(user_id: str, *, dry_run: bool):
|
||||
sess = MagicMock()
|
||||
sess.session_id = f"inner-{len(created) + 1}"
|
||||
sess.user_id = user_id
|
||||
sess.dry_run = dry_run
|
||||
sess.messages = []
|
||||
created.append(sess)
|
||||
return sess
|
||||
|
||||
async def fake_get(session_id: str):
|
||||
for s in created:
|
||||
if s.session_id == session_id:
|
||||
return s
|
||||
return None
|
||||
|
||||
# The tool modules bind these names at import time, so patch the
|
||||
# local module bindings (not the source in backend.copilot.model).
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.run_sub_session.create_chat_session", fake_create
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.run_sub_session.get_chat_session", fake_get
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session", fake_get
|
||||
)
|
||||
return {"created": created, "get": fake_get}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RunSubSessionTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunSubSession:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_prompt_returns_error(self):
|
||||
r = await RunSubSessionTool()._execute(
|
||||
user_id="u", session=_session(), prompt=""
|
||||
)
|
||||
assert isinstance(r, ErrorResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_returns_error(self):
|
||||
r = await RunSubSessionTool()._execute(
|
||||
user_id=None, session=_session(), prompt="hi"
|
||||
)
|
||||
assert isinstance(r, ErrorResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_with_other_users_session_id_rejected(
|
||||
self, monkeypatch, mock_queue, mock_waiter
|
||||
):
|
||||
"""Ownership must be re-verified when the caller passes a resume id."""
|
||||
foreign = MagicMock(session_id="alien-sess", user_id="not-caller", messages=[])
|
||||
|
||||
async def fake_get(session_id: str):
|
||||
if session_id == "alien-sess":
|
||||
return foreign
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.run_sub_session.get_chat_session", fake_get
|
||||
)
|
||||
|
||||
r = await RunSubSessionTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
prompt="continue",
|
||||
sub_autopilot_session_id="alien-sess",
|
||||
)
|
||||
assert isinstance(r, ErrorResponse)
|
||||
assert "is not a session you own" in r.message
|
||||
mock_queue["enqueue_turn"].assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propagates_dry_run_to_sub(self, mock_queue, mock_waiter, mock_model):
|
||||
"""Fresh sub-session must inherit the parent's dry_run flag."""
|
||||
parent = _session("alice")
|
||||
parent.dry_run = True
|
||||
await RunSubSessionTool()._execute(
|
||||
user_id="alice",
|
||||
session=parent,
|
||||
prompt="hi",
|
||||
wait_for_result=0, # skip the wait helper for this assertion
|
||||
)
|
||||
assert mock_model["created"], "create_chat_session was never awaited"
|
||||
assert mock_model["created"][0].dry_run is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forwards_parent_permissions_to_queue(
|
||||
self, monkeypatch, mock_queue, mock_waiter, mock_model
|
||||
):
|
||||
"""The parent's CopilotPermissions must be passed through to the
|
||||
queue primitive so the worker applies the same filter."""
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.run_sub_session.get_current_permissions",
|
||||
lambda: perms,
|
||||
)
|
||||
await RunSubSessionTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
prompt="hi",
|
||||
wait_for_result=0,
|
||||
)
|
||||
mock_waiter.assert_awaited_once()
|
||||
assert mock_waiter.await_args.kwargs["permissions"] is perms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_result_zero_returns_running(
|
||||
self, mock_queue, mock_waiter, mock_model
|
||||
):
|
||||
"""wait_for_result=0 still dispatches the job (so the sub starts)
|
||||
but the primitive returns 'running' immediately because timeout=0,
|
||||
and the tool surfaces that to the caller."""
|
||||
r = await RunSubSessionTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
prompt="hi",
|
||||
wait_for_result=0,
|
||||
)
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
assert r.status == "running"
|
||||
assert r.sub_session_id == r.sub_autopilot_session_id == "inner-1"
|
||||
assert r.sub_autopilot_session_link == "/copilot?sessionId=inner-1"
|
||||
mock_waiter.assert_awaited_once()
|
||||
assert mock_waiter.await_args.kwargs["timeout"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_result_completed_returns_final_response(
|
||||
self, mock_queue, mock_waiter, mock_model
|
||||
):
|
||||
"""When the queue primitive returns 'completed' + a SessionResult,
|
||||
the tool surfaces response_text + tool_calls directly — no DB
|
||||
round-trip needed for the content."""
|
||||
from backend.copilot.sdk.session_waiter import SessionResult
|
||||
from backend.copilot.sdk.stream_accumulator import ToolCallEntry
|
||||
|
||||
res = SessionResult()
|
||||
res.response_text = "the answer"
|
||||
res.tool_calls = [
|
||||
ToolCallEntry(
|
||||
tool_call_id="tc-1",
|
||||
tool_name="foo",
|
||||
input={"x": 1},
|
||||
output="ok",
|
||||
success=True,
|
||||
)
|
||||
]
|
||||
mock_waiter.return_value = ("completed", res)
|
||||
|
||||
r = await RunSubSessionTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
prompt="hi",
|
||||
wait_for_result=60,
|
||||
)
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
assert r.status == "completed"
|
||||
assert r.response == "the answer"
|
||||
assert r.tool_calls is not None and len(r.tool_calls) == 1
|
||||
assert r.tool_calls[0]["tool_name"] == "foo"
|
||||
mock_waiter.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queued_outcome_surfaces_queued_status(
|
||||
self, mock_queue, mock_waiter, mock_model
|
||||
):
|
||||
"""When the shared primitive reports the target session already has
|
||||
a turn running, the tool surfaces ``status='queued'`` so the LLM can
|
||||
decide whether to poll or move on."""
|
||||
from backend.copilot.sdk.session_waiter import SessionResult
|
||||
|
||||
queued_res = SessionResult(queued=True, pending_buffer_length=2)
|
||||
mock_waiter.return_value = ("queued", queued_res)
|
||||
|
||||
r = await RunSubSessionTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
prompt="please do another thing",
|
||||
wait_for_result=0,
|
||||
)
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
assert r.status == "queued"
|
||||
assert r.sub_session_id == "inner-1"
|
||||
assert "queued" in (r.message or "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_clamps_above_maximum(self, mock_queue, mock_waiter, mock_model):
|
||||
"""wait_for_result values above the cap are clamped before being
|
||||
passed to the queue primitive."""
|
||||
await RunSubSessionTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
prompt="hi",
|
||||
wait_for_result=MAX_SUB_SESSION_WAIT_SECONDS + 999,
|
||||
)
|
||||
mock_waiter.assert_awaited_once()
|
||||
assert mock_waiter.await_args.kwargs["timeout"] == MAX_SUB_SESSION_WAIT_SECONDS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GetSubSessionResultTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSubSessionResult:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_id_returns_error(self):
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="u", session=_session(), sub_session_id=""
|
||||
)
|
||||
assert isinstance(r, ErrorResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_id_returns_error(self, monkeypatch):
|
||||
async def none_get(_sid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session",
|
||||
none_get,
|
||||
)
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="u", session=_session(), sub_session_id="missing"
|
||||
)
|
||||
assert isinstance(r, ErrorResponse)
|
||||
assert "No sub-session with id missing" in r.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_other_user_cannot_access(self, monkeypatch):
|
||||
"""Cross-user lookups are indistinguishable from 'not found'."""
|
||||
foreign = MagicMock(user_id="bob", messages=[])
|
||||
|
||||
async def foreign_get(_sid):
|
||||
return foreign
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session",
|
||||
foreign_get,
|
||||
)
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="alice", session=_session("alice"), sub_session_id="bobs-sess"
|
||||
)
|
||||
assert isinstance(r, ErrorResponse)
|
||||
assert "No sub-session" in r.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_returns_running(self, monkeypatch, mock_waiter):
|
||||
sub = MagicMock(user_id="alice", messages=[])
|
||||
|
||||
async def fake_get(_sid):
|
||||
return sub
|
||||
|
||||
async def no_active_session(_sid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session",
|
||||
fake_get,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
|
||||
no_active_session,
|
||||
)
|
||||
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
sub_session_id="inner-7",
|
||||
wait_if_running=30,
|
||||
)
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
assert r.status == "running"
|
||||
assert r.sub_session_id == "inner-7"
|
||||
mock_waiter.result_mock.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_returns_completed_with_response(self, monkeypatch, mock_waiter):
|
||||
"""'completed' outcome surfaces the SessionResult directly."""
|
||||
from backend.copilot.sdk.session_waiter import SessionResult
|
||||
|
||||
sub = MagicMock(user_id="alice", messages=[]) # not terminal-looking
|
||||
|
||||
async def fake_get(_sid):
|
||||
return sub
|
||||
|
||||
async def no_active_session(_sid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session",
|
||||
fake_get,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
|
||||
no_active_session,
|
||||
)
|
||||
|
||||
res = SessionResult()
|
||||
res.response_text = "done"
|
||||
mock_waiter.result_mock.return_value = ("completed", res)
|
||||
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
sub_session_id="inner-3",
|
||||
wait_if_running=30,
|
||||
)
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
assert r.status == "completed"
|
||||
assert r.response == "done"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_terminal_skips_waiter(self, monkeypatch, mock_waiter):
|
||||
"""If the sub's last message is already terminal AND no turn is
|
||||
in flight, the tool returns 'completed' without ever calling
|
||||
wait_for_session_result — it rebuilds the response from the
|
||||
persisted message instead."""
|
||||
sub = MagicMock(user_id="alice")
|
||||
assistant = MagicMock()
|
||||
assistant.role = "assistant"
|
||||
assistant.content = "already done"
|
||||
assistant.tool_calls = None
|
||||
sub.messages = [assistant]
|
||||
|
||||
async def fake_get(_sid):
|
||||
return sub
|
||||
|
||||
async def no_active_session(_sid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session",
|
||||
fake_get,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
|
||||
no_active_session,
|
||||
)
|
||||
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
sub_session_id="inner-9",
|
||||
wait_if_running=30,
|
||||
)
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
assert r.status == "completed"
|
||||
assert r.response == "already done"
|
||||
mock_waiter.result_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_turn_in_flight_does_not_return_stale(
|
||||
self, monkeypatch, mock_waiter
|
||||
):
|
||||
"""Regression for sentry r3105409601: on a resumed session whose
|
||||
stream_registry status is 'running' (new turn is mid-flight) the
|
||||
tool must NOT short-circuit to the prior turn's terminal message.
|
||||
It subscribes to the stream like a normal running-session poll."""
|
||||
# DB state reflects the PREVIOUS turn's terminal assistant message.
|
||||
prior = MagicMock()
|
||||
prior.role = "assistant"
|
||||
prior.content = "OLD stale result"
|
||||
prior.tool_calls = None
|
||||
sub = MagicMock(user_id="alice", messages=[prior])
|
||||
|
||||
async def fake_get(_sid):
|
||||
return sub
|
||||
|
||||
running_meta = MagicMock(status="running")
|
||||
|
||||
async def active_registry(_sid):
|
||||
return running_meta
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session",
|
||||
fake_get,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
|
||||
active_registry,
|
||||
)
|
||||
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
sub_session_id="inner-11",
|
||||
wait_if_running=30,
|
||||
)
|
||||
# The waiter must have been awaited — stale short-circuit was skipped.
|
||||
mock_waiter.result_mock.assert_awaited_once()
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
# Default mock_waiter.result_mock.return_value = ("running", SessionResult())
|
||||
assert r.status == "running"
|
||||
# And crucially NOT the stale content.
|
||||
assert r.response is None or r.response == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_publishes_cancel_event(
|
||||
self, monkeypatch, mock_queue, mock_waiter
|
||||
):
|
||||
"""cancel=true fans out a CancelCoPilotEvent and returns 'cancelled'
|
||||
without waiting for the sub to finish (the worker will finalise)."""
|
||||
sub = MagicMock(user_id="alice", messages=[])
|
||||
|
||||
async def fake_get(_sid):
|
||||
return sub
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.get_sub_session_result.get_chat_session",
|
||||
fake_get,
|
||||
)
|
||||
|
||||
r = await GetSubSessionResultTool()._execute(
|
||||
user_id="alice",
|
||||
session=_session("alice"),
|
||||
sub_session_id="inner-5",
|
||||
cancel=True,
|
||||
)
|
||||
assert isinstance(r, SubSessionStatusResponse)
|
||||
assert r.status == "cancelled"
|
||||
mock_queue["enqueue_cancel"].assert_awaited_once_with("inner-5")
|
||||
mock_waiter.result_mock.assert_not_awaited()
|
||||
@@ -754,15 +754,15 @@ async def test_run_agent_session_dry_run_overrides_kwargs():
|
||||
captured_params["dry_run"] = params.dry_run
|
||||
return {}, None
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(graph, None),
|
||||
), patch.object(
|
||||
tool, "_check_prerequisites", side_effect=capture_prerequisites
|
||||
), patch.object(
|
||||
tool, "_run_agent", new_callable=AsyncMock
|
||||
) as mock_run_agent:
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(graph, None),
|
||||
),
|
||||
patch.object(tool, "_check_prerequisites", side_effect=capture_prerequisites),
|
||||
patch.object(tool, "_run_agent", new_callable=AsyncMock) as mock_run_agent,
|
||||
):
|
||||
mock_run_agent.return_value = MagicMock()
|
||||
|
||||
# Pass dry_run=False in kwargs — session.dry_run=True should win.
|
||||
@@ -796,15 +796,15 @@ async def test_run_agent_session_dry_run_false_allows_scheduling():
|
||||
captured_params["dry_run"] = params.dry_run
|
||||
return {}, None
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(graph, None),
|
||||
), patch.object(
|
||||
tool, "_check_prerequisites", side_effect=capture_prerequisites
|
||||
), patch.object(
|
||||
tool, "_schedule_agent", new_callable=AsyncMock
|
||||
) as mock_schedule:
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(graph, None),
|
||||
),
|
||||
patch.object(tool, "_check_prerequisites", side_effect=capture_prerequisites),
|
||||
patch.object(tool, "_schedule_agent", new_callable=AsyncMock) as mock_schedule,
|
||||
):
|
||||
mock_schedule.return_value = MagicMock()
|
||||
|
||||
await tool._execute(
|
||||
@@ -840,15 +840,15 @@ async def test_run_agent_session_dry_run_false_allows_llm_dry_run_true():
|
||||
captured_params["dry_run"] = params.dry_run
|
||||
return {}, None
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(graph, None),
|
||||
), patch.object(
|
||||
tool, "_check_prerequisites", side_effect=capture_prerequisites
|
||||
), patch.object(
|
||||
tool, "_run_agent", new_callable=AsyncMock
|
||||
) as mock_run_agent:
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(graph, None),
|
||||
),
|
||||
patch.object(tool, "_check_prerequisites", side_effect=capture_prerequisites),
|
||||
patch.object(tool, "_run_agent", new_callable=AsyncMock) as mock_run_agent,
|
||||
):
|
||||
mock_run_agent.return_value = MagicMock()
|
||||
|
||||
# LLM passes dry_run=True; normal session must NOT override it to False
|
||||
|
||||
@@ -7,6 +7,7 @@ from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator.validation import AgentValidator, get_blocks_as_dicts
|
||||
from .base import BaseTool
|
||||
from .helpers import require_guide_read
|
||||
from .models import ErrorResponse, ToolResponseBase, ValidationResultResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,7 +25,8 @@ class ValidateAgentGraphTool(BaseTool):
|
||||
return (
|
||||
"Validate agent JSON for correctness: block_ids, links, required fields, "
|
||||
"type compatibility, nested sink notation, prompt brace escaping, "
|
||||
"and AgentExecutorBlock configs. On failure, use fix_agent_graph to auto-fix."
|
||||
"and AgentExecutorBlock configs. On failure, use fix_agent_graph to auto-fix. "
|
||||
"Requires get_agent_building_guide first (refuses otherwise)."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -53,6 +55,10 @@ class ValidateAgentGraphTool(BaseTool):
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
guide_gate = require_guide_read(session, "validate_agent_graph")
|
||||
if guide_gate is not None:
|
||||
return guide_gate
|
||||
|
||||
if not agent_json or not isinstance(agent_json, dict):
|
||||
return ErrorResponse(
|
||||
message="Please provide a valid agent JSON object.",
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""JSONL transcript management for stateless multi-turn resume.
|
||||
|
||||
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
||||
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
||||
(progress entries, metadata), and upload the result to bucket storage. On the
|
||||
next turn we download the transcript, write it to a temp file, and pass
|
||||
``--resume`` so the CLI can reconstruct the full conversation.
|
||||
line). When the SDK's ``Stop`` hook fires the caller reads this file, strips
|
||||
bloat (progress entries, metadata), and uploads the result to bucket storage.
|
||||
On the next turn the caller downloads the bytes and writes them to disk before
|
||||
passing ``--resume`` so the CLI can reconstruct the full conversation.
|
||||
|
||||
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
@@ -20,6 +20,7 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset(
|
||||
)
|
||||
|
||||
|
||||
TranscriptMode = Literal["sdk", "baseline"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
"""Result of downloading a transcript with its metadata."""
|
||||
|
||||
content: str
|
||||
message_count: int = 0 # session.messages length when uploaded
|
||||
uploaded_at: float = 0.0 # epoch timestamp of upload
|
||||
content: bytes | str
|
||||
message_count: int = 0
|
||||
# "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
|
||||
mode: TranscriptMode = "sdk"
|
||||
|
||||
|
||||
# Workspace storage constants — deterministic path from session_id.
|
||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
||||
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
|
||||
_CLI_SESSION_STORAGE_PREFIX = "cli-sessions"
|
||||
|
||||
@@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _projects_base() -> str:
|
||||
def projects_base() -> str:
|
||||
"""Return the resolved path to the CLI's projects directory."""
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
@@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
projects_base = _projects_base()
|
||||
if not os.path.isdir(projects_base):
|
||||
_pbase = projects_base()
|
||||
if not os.path.isdir(_pbase):
|
||||
return 0
|
||||
|
||||
now = time.time()
|
||||
@@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
target = Path(_pbase) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
@@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
entries = Path(_pbase).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
@@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
if not transcript_path:
|
||||
return None
|
||||
|
||||
projects_base = _projects_base()
|
||||
_pbase = projects_base()
|
||||
real_path = os.path.realpath(transcript_path)
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
if not real_path.startswith(_pbase + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] transcript_path outside projects base: %s", transcript_path
|
||||
)
|
||||
@@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript.
|
||||
|
||||
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
|
||||
IDs are sanitized to hex+hyphen to prevent path traversal.
|
||||
"""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.jsonl",
|
||||
)
|
||||
|
||||
|
||||
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.meta.json",
|
||||
)
|
||||
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
wid, fid, fname = parts
|
||||
@@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path string that ``retrieve()`` expects."""
|
||||
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
|
||||
|
||||
|
||||
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path for the companion .meta.json file."""
|
||||
return _build_path_from_parts(
|
||||
_meta_storage_path_parts(user_id, session_id), backend
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI native session file — cross-pod --resume support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
def cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
"""Expected path of the CLI's native session JSONL file.
|
||||
|
||||
The CLI resolves the working directory via ``os.path.realpath``, then
|
||||
@@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
"""
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
safe_id = _sanitize_id(session_id)
|
||||
return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl")
|
||||
return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl")
|
||||
|
||||
|
||||
def _cli_session_storage_path_parts(
|
||||
@@ -689,235 +659,82 @@ def _cli_session_storage_path_parts(
|
||||
)
|
||||
|
||||
|
||||
async def upload_cli_session(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Upload the CLI's native session JSONL file to remote storage.
|
||||
|
||||
Called after each turn so the next turn can restore the file on any pod
|
||||
(eliminating the pod-affinity requirement for --resume).
|
||||
|
||||
The CLI only writes the session file after the turn completes, so this
|
||||
must run in the finally block, AFTER the SDK stream has finished.
|
||||
"""
|
||||
session_file = _cli_session_path(sdk_cwd, session_id)
|
||||
real_path = os.path.realpath(session_file)
|
||||
projects_base = _projects_base()
|
||||
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"%s CLI session file outside projects base, skipping upload: %s",
|
||||
log_prefix,
|
||||
os.path.basename(real_path),
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
raw_bytes = Path(real_path).read_bytes()
|
||||
except FileNotFoundError:
|
||||
logger.debug(
|
||||
"%s CLI session file not found, skipping upload: %s",
|
||||
log_prefix,
|
||||
session_file,
|
||||
)
|
||||
return
|
||||
except OSError as e:
|
||||
logger.warning("%s Failed to read CLI session file: %s", log_prefix, e)
|
||||
return
|
||||
|
||||
# Strip stale thinking blocks and metadata entries (progress, file-history-snapshot,
|
||||
# queue-operation) from the CLI session before writing it back locally and uploading
|
||||
# to GCS. Thinking blocks from non-last assistant turns are not needed for --resume
|
||||
# but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact
|
||||
# its session when the context window fills up. Stripping keeps the session well below
|
||||
# the ~200K-token compaction threshold and prevents silent context loss.
|
||||
try:
|
||||
raw_text = raw_bytes.decode("utf-8")
|
||||
stripped_text = strip_for_upload(raw_text)
|
||||
stripped_bytes = stripped_text.encode("utf-8")
|
||||
if len(stripped_bytes) < len(raw_bytes):
|
||||
# Write the stripped version back locally so same-pod turns also benefit.
|
||||
Path(real_path).write_bytes(stripped_bytes)
|
||||
logger.info(
|
||||
"%s Stripped CLI session file: %dB → %dB",
|
||||
log_prefix,
|
||||
len(raw_bytes),
|
||||
len(stripped_bytes),
|
||||
)
|
||||
content = stripped_bytes
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e
|
||||
)
|
||||
content = raw_bytes
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=wid, file_id=fid, filename=fname, content=content
|
||||
)
|
||||
logger.info(
|
||||
"%s Uploaded CLI session file (%dB) for cross-pod --resume",
|
||||
log_prefix,
|
||||
len(content),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to upload CLI session file: %s", log_prefix, e)
|
||||
|
||||
|
||||
async def restore_cli_session(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> bool:
|
||||
"""Download and restore the CLI's native session file for --resume.
|
||||
|
||||
Returns True if the file was successfully restored and --resume can be
|
||||
used with the session UUID. Returns False if not available (first turn
|
||||
or upload failed), in which case the caller should not set --resume.
|
||||
"""
|
||||
session_file = _cli_session_path(sdk_cwd, session_id)
|
||||
real_path = os.path.realpath(session_file)
|
||||
projects_base = _projects_base()
|
||||
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"%s CLI session restore path outside projects base: %s",
|
||||
log_prefix,
|
||||
os.path.basename(session_file),
|
||||
)
|
||||
return False
|
||||
|
||||
# If the session file already exists locally (same-pod reuse), use it directly.
|
||||
# Downloading from storage could overwrite a newer local version when a previous
|
||||
# turn's upload failed: stored content is stale while the local file already
|
||||
# contains extended history from that turn.
|
||||
if Path(real_path).exists():
|
||||
logger.debug(
|
||||
"%s CLI session file already exists locally — using it for --resume",
|
||||
log_prefix,
|
||||
)
|
||||
return True
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for the CLI session meta file."""
|
||||
return (
|
||||
_CLI_SESSION_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.meta.json",
|
||||
)
|
||||
|
||||
try:
|
||||
content = await storage.retrieve(path)
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download CLI session: %s", log_prefix, e)
|
||||
return False
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(real_path), exist_ok=True)
|
||||
Path(real_path).write_bytes(content)
|
||||
logger.info(
|
||||
"%s Restored CLI session file (%dB) for --resume",
|
||||
log_prefix,
|
||||
len(content),
|
||||
)
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.warning("%s Failed to write CLI session file: %s", log_prefix, e)
|
||||
return False
|
||||
|
||||
|
||||
async def upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
content: str,
|
||||
content: bytes,
|
||||
message_count: int = 0,
|
||||
mode: TranscriptMode = "sdk",
|
||||
log_prefix: str = "[Transcript]",
|
||||
skip_strip: bool = False,
|
||||
) -> None:
|
||||
"""Strip progress entries and stale thinking blocks, then upload transcript.
|
||||
"""Upload CLI session content to GCS with companion meta.json.
|
||||
|
||||
The transcript represents the FULL active context (atomic).
|
||||
Each upload REPLACES the previous transcript entirely.
|
||||
Pure GCS operation — no disk I/O. The caller is responsible for reading
|
||||
the session file from disk before calling this function.
|
||||
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen.
|
||||
Also uploads a companion .meta.json with the message_count watermark so
|
||||
download_transcript can return it without a separate fetch.
|
||||
|
||||
Args:
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
skip_strip: When ``True``, skip the strip + re-validate pass.
|
||||
Safe for builder-generated content (baseline path) which
|
||||
never emits progress entries or stale thinking blocks.
|
||||
Called after each turn so the next turn can restore the file on any pod
|
||||
(eliminating the pod-affinity requirement for --resume).
|
||||
"""
|
||||
if skip_strip:
|
||||
# Caller guarantees the content is already clean and valid.
|
||||
stripped = content
|
||||
else:
|
||||
# Strip metadata entries and stale thinking blocks in a single parse.
|
||||
# SDK-built transcripts may have progress entries; strip for safety.
|
||||
stripped = strip_for_upload(content)
|
||||
if not skip_strip and not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
log_prefix,
|
||||
entry_types,
|
||||
len(stripped),
|
||||
len(content),
|
||||
)
|
||||
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
|
||||
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
|
||||
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
|
||||
meta_encoded = json.dumps(meta).encode("utf-8")
|
||||
|
||||
# Transcript + metadata are independent objects at different keys, so
|
||||
# write them concurrently. ``return_exceptions`` keeps a metadata
|
||||
# failure from sinking the transcript write.
|
||||
transcript_result, metadata_result = await asyncio.gather(
|
||||
storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
),
|
||||
storage.store(
|
||||
workspace_id=mwid,
|
||||
file_id=mfid,
|
||||
filename=mfname,
|
||||
content=meta_encoded,
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(transcript_result, BaseException):
|
||||
raise transcript_result
|
||||
if isinstance(metadata_result, BaseException):
|
||||
# Metadata is best-effort — the gap-fill logic in
|
||||
# _build_query_message tolerates a missing metadata file.
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
|
||||
# Write JSONL first, meta second — sequential so a crash between the two
|
||||
# leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong
|
||||
# watermark / mode paired with stale or absent content).
|
||||
# On any failure we roll back the other file so the pair is always absent
|
||||
# together; download_transcript returns None when either file is missing.
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=wid, file_id=fid, filename=fname, content=content
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.warning(
|
||||
"%s Failed to upload CLI session file: %s", log_prefix, session_err
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded
|
||||
)
|
||||
except Exception as meta_err:
|
||||
logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err)
|
||||
# Roll back the JSONL so neither file exists — avoids orphaned JSONL being
|
||||
# used with wrong mode/watermark defaults on the next restore.
|
||||
try:
|
||||
session_path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
)
|
||||
await storage.delete(session_path)
|
||||
except Exception as rollback_err:
|
||||
logger.debug(
|
||||
"%s Session rollback failed (harmless — download will return None): %s",
|
||||
log_prefix,
|
||||
rollback_err,
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
"%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -926,83 +743,181 @@ async def download_transcript(
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
|
||||
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
Pure GCS operation — no disk I/O. The caller is responsible for writing
|
||||
content to disk if --resume is needed.
|
||||
|
||||
The content and metadata fetches run concurrently since they are
|
||||
independent objects in the bucket.
|
||||
Returns a TranscriptDownload with the raw content, message_count watermark,
|
||||
and mode on success, or None if not available (first turn or upload failed).
|
||||
"""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
)
|
||||
meta_path = _build_path_from_parts(
|
||||
_cli_session_meta_path_parts(user_id, session_id), storage
|
||||
)
|
||||
|
||||
content_task = asyncio.create_task(storage.retrieve(path))
|
||||
meta_task = asyncio.create_task(storage.retrieve(meta_path))
|
||||
content_result, meta_result = await asyncio.gather(
|
||||
content_task, meta_task, return_exceptions=True
|
||||
storage.retrieve(path),
|
||||
storage.retrieve(meta_path),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if isinstance(content_result, FileNotFoundError):
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
|
||||
return None
|
||||
if isinstance(content_result, BaseException):
|
||||
logger.warning(
|
||||
"%s Failed to download transcript: %s", log_prefix, content_result
|
||||
"%s Failed to download CLI session: %s", log_prefix, content_result
|
||||
)
|
||||
return None
|
||||
|
||||
content = content_result.decode("utf-8")
|
||||
content: bytes = content_result
|
||||
|
||||
# Metadata is best-effort — old transcripts won't have it.
|
||||
# Parse message_count and mode from companion meta — best-effort, defaults.
|
||||
message_count = 0
|
||||
uploaded_at = 0.0
|
||||
mode: TranscriptMode = "sdk"
|
||||
if isinstance(meta_result, FileNotFoundError):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
pass # No meta — old upload; default to "sdk"
|
||||
elif isinstance(meta_result, BaseException):
|
||||
logger.debug(
|
||||
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
|
||||
)
|
||||
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
|
||||
else:
|
||||
meta = json.loads(meta_result.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
try:
|
||||
meta_str = meta_result.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix)
|
||||
meta_str = None
|
||||
if meta_str is not None:
|
||||
meta = json.loads(meta_str, fallback={})
|
||||
if isinstance(meta, dict):
|
||||
raw_count = meta.get("message_count", 0)
|
||||
message_count = (
|
||||
raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0
|
||||
)
|
||||
raw_mode = meta.get("mode", "sdk")
|
||||
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
uploaded_at=uploaded_at,
|
||||
"%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)",
|
||||
log_prefix,
|
||||
len(content),
|
||||
message_count,
|
||||
mode,
|
||||
)
|
||||
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
|
||||
|
||||
|
||||
def detect_gap(
|
||||
download: TranscriptDownload,
|
||||
session_messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Return chat-db messages after the transcript watermark (excluding current user turn).
|
||||
|
||||
Returns [] if transcript is current, watermark is zero, or the watermark
|
||||
position doesn't end on an assistant turn (misaligned watermark).
|
||||
"""
|
||||
if download.message_count == 0:
|
||||
return []
|
||||
wm = download.message_count
|
||||
total = len(session_messages)
|
||||
if wm >= total - 1:
|
||||
return []
|
||||
# Sanity: position wm-1 should be an assistant turn; misaligned watermark
|
||||
# means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
|
||||
# In normal operation ``message_count`` is always written after a complete
|
||||
# user→assistant exchange (never mid-turn), so the last covered position is
|
||||
# always assistant. This guard fires only on data corruption or message deletion.
|
||||
if session_messages[wm - 1].role != "assistant":
|
||||
return []
|
||||
return list(session_messages[wm : total - 1])
|
||||
|
||||
|
||||
def extract_context_messages(
|
||||
download: TranscriptDownload | None,
|
||||
session_messages: "list[ChatMessage]",
|
||||
) -> "list[ChatMessage]":
|
||||
"""Return context messages for the current turn: transcript content + gap.
|
||||
|
||||
This is the shared context primitive used by both the SDK path
|
||||
(``use_resume=False`` → ``<conversation_history>`` injection) and the
|
||||
baseline path (OpenAI messages array).
|
||||
|
||||
How it works:
|
||||
|
||||
- When a transcript exists, ``TranscriptBuilder.load_previous`` preserves
|
||||
``isCompactSummary=True`` compaction entries, so the returned messages
|
||||
mirror the compacted context the CLI would see via ``--resume``.
|
||||
- The gap (DB messages after the transcript watermark) is always small in
|
||||
normal operation; it only grows during mode switches or when an upload
|
||||
was missed.
|
||||
- Falls back to full DB messages when no transcript exists (first turn,
|
||||
upload failure, or GCS unavailable).
|
||||
- Returns *prior* messages only (excluding the current user turn at
|
||||
``session_messages[-1]``). Callers that need the current turn append
|
||||
``session_messages[-1]`` themselves.
|
||||
- **Tool calls from transcript entries are flattened to text**: assistant
|
||||
messages derived from the JSONL use ``_flatten_assistant_content``, which
|
||||
serialises ``tool_use`` blocks as human-readable text rather than
|
||||
structured ``tool_calls``. Gap messages (from DB) preserve their
|
||||
original ``tool_calls`` field. This is the same trade-off as the old
|
||||
``_compress_session_messages(session.messages)`` approach — no regression.
|
||||
|
||||
Args:
|
||||
download: The ``TranscriptDownload`` from GCS, or ``None`` when no
|
||||
transcript is available. ``content`` may be either ``bytes`` or
|
||||
``str`` (the baseline path decodes + strips before returning).
|
||||
session_messages: All messages in the session, with the current user
|
||||
turn as the last element.
|
||||
|
||||
Returns:
|
||||
A list of ``ChatMessage`` objects covering the prior conversation
|
||||
context, suitable for injection as conversation history.
|
||||
"""
|
||||
from .model import ChatMessage as _ChatMessage # runtime import
|
||||
|
||||
# ``role="reasoning"`` rows are persisted for frontend replay of
|
||||
# extended_thinking content but are NOT conversation context — the
|
||||
# transcript-based --resume path already carries thinking separately,
|
||||
# and sending them back to the model as user/assistant turns would be
|
||||
# both redundant and malformed. Drop them before any gap detection
|
||||
# or transcript comparison so ordering invariants still hold.
|
||||
session_messages = [m for m in session_messages if m.role != "reasoning"]
|
||||
|
||||
prior = session_messages[:-1]
|
||||
|
||||
if download is None:
|
||||
return prior
|
||||
|
||||
raw_content = download.content
|
||||
if not raw_content:
|
||||
return prior
|
||||
|
||||
# Handle both bytes (raw GCS download) and str (pre-decoded baseline path).
|
||||
if isinstance(raw_content, bytes):
|
||||
try:
|
||||
content_str: str = raw_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return prior
|
||||
else:
|
||||
content_str = raw_content
|
||||
|
||||
raw = _transcript_to_messages(content_str)
|
||||
if not raw:
|
||||
return prior
|
||||
|
||||
transcript_msgs = [
|
||||
_ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw
|
||||
]
|
||||
gap = detect_gap(download, session_messages)
|
||||
return transcript_msgs + gap
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript and its metadata from bucket storage.
|
||||
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
"""Delete CLI session JSONL and its companion .meta.json from bucket storage."""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
await storage.delete(path)
|
||||
logger.info("[Transcript] Deleted transcript for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete transcript: %s", e)
|
||||
|
||||
# Also delete the companion .meta.json to avoid orphaned metadata.
|
||||
try:
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
await storage.delete(meta_path)
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
# Also delete the CLI native session file to prevent storage growth.
|
||||
try:
|
||||
cli_path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
@@ -1012,6 +927,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete CLI session: %s", e)
|
||||
|
||||
try:
|
||||
cli_meta_path = _build_path_from_parts(
|
||||
_cli_session_meta_path_parts(user_id, session_id), storage
|
||||
)
|
||||
await storage.delete(cli_meta_path)
|
||||
logger.info("[Transcript] Deleted CLI session meta for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete CLI session meta: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user