Compare commits

288 Commits

Author SHA1 Message Date
majdyz
12601f3ab9 fix(copilot): cap sessionModes at 200 entries to prevent localStorage leak 2026-04-13 12:54:53 +00:00
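A minimal sketch of the capping idea, written in Python for illustration (the store itself lives in the frontend). The 200-entry cap comes from the commit subject; the oldest-first eviction policy and the names `MAX_SESSION_MODES` and `record_session_mode` are assumptions, since the commit does not say which entries are dropped.

```python
from collections import OrderedDict

MAX_SESSION_MODES = 200  # cap from the commit subject


def record_session_mode(modes: OrderedDict, session_id: str, mode: str) -> None:
    """Record a session's mode and evict the oldest entries beyond the cap.

    Oldest-first eviction is an assumption; the commit only states the cap.
    """
    modes[session_id] = mode
    modes.move_to_end(session_id)  # most recently recorded entry goes last
    while len(modes) > MAX_SESSION_MODES:
        modes.popitem(last=False)  # drop the oldest entry
```

With this policy, recording 250 sessions leaves only the 200 most recent ones in the map.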
majdyz
47be9c7024 fix(copilot): default to thinking mode for sessions without recorded mode
Sessions created before the mode fix had no recorded mode. Previously
restoreSessionMode would leave the global mode unchanged (whatever it
was set to on another session). Now defaults to extended_thinking when
no mode is recorded — no need to clear localStorage.
2026-04-13 12:52:01 +00:00
majdyz
c9fadf20e1 fix(copilot): record current session mode before switching away
Old sessions (created before the mode fix) didn't have a recorded
mode, so switching away and back would lose the mode. Now we record
the current mode for the departing session before switching.
2026-04-13 12:48:03 +00:00
majdyz
7d16258a98 perf(copilot): reduce tool output truncation from 500K to 100K chars
500K chars (~125K tokens) per tool result was too generous — a few
large tool outputs could push context past 200K+ tokens. 100K chars
(~25K tokens) keeps individual results reasonable. The SDK writes
oversized results to tool-results/ files and returns a reference.
2026-04-13 12:24:35 +00:00
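The character-to-token figures above imply roughly 4 characters per token. A small sketch of that arithmetic and of a plain truncation step follows; the helper names are hypothetical, and unlike the SDK behaviour described in the commit, this sketch only cuts the text and does not write the overflow to a tool-results/ file.

```python
MAX_TOOL_OUTPUT_CHARS = 100_000  # new limit from the commit
CHARS_PER_TOKEN = 4  # rough ratio implied by "500K chars (~125K tokens)"


def approx_tokens(chars: int) -> int:
    """Rough token estimate for a given character count."""
    return chars // CHARS_PER_TOKEN


def truncate_tool_output(text: str, limit: int = MAX_TOOL_OUTPUT_CHARS) -> str:
    """Return text unchanged if it fits, else cut it and note how much was dropped."""
    if len(text) <= limit:
        return text
    omitted = len(text) - limit
    return text[:limit] + f"\n[truncated {omitted} chars]"
```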
majdyz
ac054c31f6 perf(copilot): trigger compaction at 100K tokens instead of 140K
Set CLAUDE_AUTOCOMPACT_PCT_OVERRIDE=50 to compact at 50% of 200K
context window (100K) instead of the default 70% (140K). Calls with
context above 200K tokens account for 54% of cost despite being only 3% of calls.
Earlier compaction keeps context smaller and reduces cache creation.
2026-04-13 12:15:52 +00:00
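The threshold arithmetic in this commit can be checked with a one-line helper. The function name is hypothetical; the 200K window and the 50% and 70% values are the ones stated above.

```python
CONTEXT_WINDOW_TOKENS = 200_000  # window size stated in the commit


def compaction_threshold(pct: int, window: int = CONTEXT_WINDOW_TOKENS) -> int:
    """Token count at which compaction triggers for a given percentage override."""
    return window * pct // 100
```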
majdyz
1d3cce0ebf fix(copilot): strip <internal_reasoning> tags from Sonnet response stream
Models without extended thinking (e.g. Sonnet) sometimes emit
<internal_reasoning>...</internal_reasoning> tags as visible text.
Extract ThinkingStripper to a shared module and apply it to the SDK
streaming path so these tags are stripped before reaching the SSE
client and the persisted message.
2026-04-13 11:50:43 +00:00
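The commit does not show ThinkingStripper itself. The sketch below is one way a streaming stripper could work, assuming the tags can be split across chunks; the class name `ReasoningStripper` and its `feed`/`flush` methods are hypothetical and not the project's API.

```python
OPEN_TAG = "<internal_reasoning>"
CLOSE_TAG = "</internal_reasoning>"


def _partial_suffix_len(text: str, tag: str) -> int:
    """Length of the longest suffix of text that is a proper prefix of tag."""
    for n in range(min(len(text), len(tag) - 1), 0, -1):
        if text.endswith(tag[:n]):
            return n
    return 0


class ReasoningStripper:
    """Drops <internal_reasoning>...</internal_reasoning> spans from chunked text."""

    def __init__(self) -> None:
        self._buffer = ""
        self._inside = False  # True while between an open and a close tag

    def feed(self, chunk: str) -> str:
        """Consume one chunk and return the text that is safe to emit now."""
        self._buffer += chunk
        out = []
        while True:
            tag = CLOSE_TAG if self._inside else OPEN_TAG
            idx = self._buffer.find(tag)
            if idx != -1:
                if not self._inside:
                    out.append(self._buffer[:idx])
                self._buffer = self._buffer[idx + len(tag):]
                self._inside = not self._inside
                continue
            # No complete tag: hold back a tail that might be the start of one.
            keep = _partial_suffix_len(self._buffer, tag)
            cut = len(self._buffer) - keep
            if not self._inside:
                out.append(self._buffer[:cut])
            self._buffer = self._buffer[cut:]
            return "".join(out)

    def flush(self) -> str:
        """Emit any held-back visible text at end of stream."""
        tail = "" if self._inside else self._buffer
        self._buffer = ""
        return tail
```

Holding back a possible tag prefix is what lets the stripper handle a tag that arrives in two chunks; an unclosed reasoning span is dropped at `flush`.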
majdyz
ea1d8485f5 fix: resolve openapi.json merge conflict — keep cost_bearing_request_count 2026-04-13 11:39:01 +00:00
majdyz
364d98aab6 fix(copilot): remove effort=low default to prevent internal_reasoning leak
effort=low on Sonnet causes <internal_reasoning> tags to leak into
visible output. Changed default to None (let model decide). Only
passed to SDK when explicitly set via CHAT_CLAUDE_AGENT_THINKING_EFFORT.
2026-04-13 11:36:16 +00:00
majdyz
f121dcd5c8 Resolve merge conflicts in copilot baseline service files
Keep HEAD's pre-drain count logic for transcript loading and drain error
handling, and merge incoming cache token extraction tests from PR #12762.
2026-04-13 10:49:02 +00:00
majdyz
ea0b5f70ad Fix merge conflict in platform_cost.py crashing all new pods
Resolve conflicts between cost dashboard PR (#12757) and cache token
columns PR (#12762). Keep all HEAD-side functionality (percentile
queries, histogram buckets, cost-bearing request counts, unfiltered
aggregate) while retaining cache token fields from the incoming side.
2026-04-13 10:37:49 +00:00
majdyz
dbaaa88e1b perf(copilot): switch default model from Opus to Sonnet
Opus at $15/$75 per M tokens is unsustainable for agentic sessions
(1M+ context after 30+ turns = $7+/turn). Sonnet at $3/$15 per M
is 5x cheaper with comparable quality for most tasks.

Override via CHAT_MODEL=anthropic/claude-opus-4.6 for premium tier.
2026-04-13 10:25:49 +00:00
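The 5x figure follows directly from the list prices quoted above. A rough sketch of the per-turn arithmetic, ignoring prompt-cache discounts (so it will not reproduce the $7+/turn figure exactly); the dictionary keys and function name are hypothetical.

```python
# USD per million tokens as (input, output), taken from the commit message.
PRICES_PER_MILLION = {"opus": (15.0, 75.0), "sonnet": (3.0, 15.0)}


def turn_cost_usd(model: str, input_tokens: int, output_tokens: int) -> float:
    """List-price cost of one turn, with no caching discounts applied."""
    in_price, out_price = PRICES_PER_MILLION[model]
    return (input_tokens * in_price + output_tokens * out_price) / 1_000_000
```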
majdyz
ec2acfb9e3 fix(frontend): add cache token fields to UserCostSummary in openapi.json
The backend added total_cache_read_tokens and total_cache_creation_tokens
to UserCostSummary but the OpenAPI spec was not updated, causing frontend
build failures.
2026-04-13 10:13:18 +00:00
majdyz
69e9a5bb22 fix(frontend): add cache token fields to UserCostSummary in openapi.json
The backend added total_cache_read_tokens and total_cache_creation_tokens
to UserCostSummary but the OpenAPI spec was not updated, causing frontend
build failures.
2026-04-13 10:12:44 +00:00
majdyz
95087cd170 Merge branch 'fix/copilot-mode-per-session' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 09:58:49 +00:00
majdyz
1e7eadce26 fix(copilot): validate persisted session modes, add removeSessionMode, fix useEffect deps
- Validate entries from localStorage before constructing the sessionModes map,
  filtering out corrupt/unknown mode strings (addresses CodeRabbit review)
- Add removeSessionMode action and call it on session delete so the map does
  not grow unboundedly
- Add recordSessionMode to the useEffect dependency array to avoid stale-closure risk
- Add clarifying comment to restoreSessionMode no-op branch
- Extend tests to cover removeSessionMode, no-op, and corrupt-localStorage behaviour
2026-04-13 09:57:14 +00:00
majdyz
1485d1910c Merge branch 'fix/sse-replay-deduplication' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 09:56:12 +00:00
majdyz
89c9c649d8 fix: resolve merge conflicts in UserTable.tsx — keep all columns (avg cost + cache read/write) 2026-04-13 09:55:56 +00:00
majdyz
a17f05f2b1 fix(copilot): scope dedup fingerprint by user message ID instead of text
Using user message text as the context key caused the deduplicator to
drop the second assistant reply when a user asked the same question twice
in one session. Switching to user message ID (which is unique per turn)
fixes the false positive while still preventing SSE-replayed duplicates.

Adds a regression test covering the same-question-twice scenario.
2026-04-13 09:55:54 +00:00
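The fingerprint change can be illustrated with a small Python rendering of the idea (the real deduplicateMessages is frontend code and is not shown here). The message shape with `id`, `role`, and `text` keys is an assumption, as is the premise that replayed user messages keep their original IDs while assistant messages get new ones.

```python
def deduplicate_messages(messages: list[dict]) -> list[dict]:
    """Drop replayed messages, keying assistant replies by the preceding user message ID."""
    seen_user_ids = set()
    seen_fingerprints = set()
    kept = []
    context_id = None  # ID of the most recent user message
    for msg in messages:
        if msg["role"] == "user":
            context_id = msg["id"]
            if msg["id"] in seen_user_ids:
                continue  # replayed user message
            seen_user_ids.add(msg["id"])
            kept.append(msg)
            continue
        fingerprint = (msg["role"], context_id, msg["text"])
        if fingerprint in seen_fingerprints:
            continue  # replayed assistant message under a new ID
        seen_fingerprints.add(fingerprint)
        kept.append(msg)
    return kept
```

Because the context key is the user message ID rather than its text, asking the same question twice yields two distinct fingerprints, so both replies survive.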
majdyz
62e4a8d3a4 Merge branch 'fix/copilot-mode-per-session' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 09:54:21 +00:00
majdyz
c6af52033d fix(copilot): fix multi-turn cost over-estimation and add cache_creation_tokens extraction
Bug 1: Fallback cost estimation was using accumulated turn_prompt_tokens /
turn_completion_tokens across all tool-call rounds, causing compounding
over-estimation on the 2nd+ turn. Snapshot token counts before each call and
pass only the per-call delta to _estimate_cost_from_tokens.

Bug 2: turn_cache_creation_tokens was defined but never populated. Extract
cache_creation_input_tokens from prompt_tokens_details (available from some
providers such as Anthropic via OpenRouter).

Add regression tests for both fixes.
2026-04-13 09:53:05 +00:00
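Bug 1 is easiest to see with numbers. The sketch below contrasts pricing the running totals on every round with pricing only the per-call delta; the rates (3 and 15 microdollars per token, matching $3/$15 per million) and all function names are illustrative assumptions rather than the project's code.

```python
def estimate_cost_microdollars(prompt_tokens: int, completion_tokens: int) -> int:
    """Illustrative rates: 3 microdollars per prompt token, 15 per completion token."""
    return prompt_tokens * 3 + completion_tokens * 15


def cost_from_running_totals(calls: list[tuple[int, int]]) -> int:
    """Buggy variant: prices the accumulated totals again on every round."""
    total_prompt = total_completion = cost = 0
    for prompt, completion in calls:
        total_prompt += prompt
        total_completion += completion
        cost += estimate_cost_microdollars(total_prompt, total_completion)
    return cost


def cost_from_deltas(calls: list[tuple[int, int]]) -> int:
    """Fixed variant: snapshots the totals and prices only what each call added."""
    total_prompt = total_completion = cost = 0
    for prompt, completion in calls:
        before_prompt, before_completion = total_prompt, total_completion
        total_prompt += prompt
        total_completion += completion
        cost += estimate_cost_microdollars(
            total_prompt - before_prompt, total_completion - before_completion
        )
    return cost
```

For two rounds of (1000, 100) and (2000, 200) tokens, the running-total variant charges the first round twice.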
majdyz
1df9369dc3 perf(copilot): add effort=low thinking control + raise budget to $15
- Add claude_agent_thinking_effort config (default: 'low') to control
  thinking depth. 'low' minimizes thinking token usage — the #1 cost
  driver at 49% of total spend.
- Raise max_budget_usd from $5 to $15 — $5 was below p50 ($5.37),
  causing half of all turns to get budget-killed mid-task.
- Log raw SDK usage dict to discover thinking token fields.
2026-04-13 09:43:26 +00:00
majdyz
f6c7d1eaf7 fix(copilot): baseline cost tracking fallback and dashboard cache token display
When OpenRouter's x-total-cost header is missing, estimate cost from
token counts using a known model pricing table so cost is always logged.
Also extract cache token details from streaming usage chunks
(prompt_tokens_details.cached_tokens) and pass them through to
PlatformCostLog.

On the dashboard side, add cache read/write columns to the logs table
and user table, and include cache tokens in the UserCostSummary backend
model so they surface in the API response.
2026-04-13 09:39:44 +00:00
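A hedged sketch of the fallback described above: prefer the reported cost when the header is present, otherwise estimate from token counts. The pricing-table keys, the function name, and the exact-case header lookup are assumptions for illustration; the project's real table and header handling are not shown in this log.

```python
# Hypothetical pricing table in USD per million tokens: (input, output).
PRICING_PER_MILLION = {
    "example/sonnet": (3.0, 15.0),
    "example/opus": (15.0, 75.0),
}


def resolve_cost_usd(headers: dict, model: str, prompt_tokens: int, completion_tokens: int):
    """Return the header-reported cost, else a token-based estimate, else None."""
    raw = headers.get("x-total-cost")
    if raw is not None:
        return float(raw)
    prices = PRICING_PER_MILLION.get(model)
    if prices is None:
        return None  # unknown model: no basis for an estimate
    in_price, out_price = prices
    return (prompt_tokens * in_price + completion_tokens * out_price) / 1_000_000
```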
majdyz
85f76230a9 debug(copilot): log raw SDK usage dict to discover thinking token fields
Temporary debug logging to see all fields in ResultMessage.usage —
need to confirm if thinking_tokens or similar is available but not
being captured.
2026-04-13 09:35:05 +00:00
majdyz
f63440e955 fix(copilot): store mode per session so indicator updates on switch
The copilot mode (fast/extended_thinking) was stored as a single global
value. When switching between sessions, the mode indicator stayed on
whatever was last set globally rather than reflecting the mode each
session was created with.

Add a sessionModes map to the Zustand store that records the active
copilotMode when a session is created and restores it when the user
switches back to that session.
2026-04-13 09:32:45 +00:00
majdyz
f52c1e1f24 fix(copilot): raise max_budget_usd from $5 to $15
$5 was too aggressive — p50 cost is $5.37 so half of all turns were
getting budget-killed mid-task with no value delivered. $15 covers p75
($13.07) so ~75% of tasks complete. The thinking token cap is the
better cost lever but needs verification first.
2026-04-13 08:47:16 +00:00
majdyz
b5216da2d8 fix(copilot): disable gzip on API responses to prevent ZlibError
Add Accept-Encoding: identity to ANTHROPIC_CUSTOM_HEADERS in
build_sdk_env() to prevent ZlibError decompression failures in the
CLI subprocess. Appended after any existing custom headers (OpenRouter
trace headers).

See: oven-sh/bun#23149, anthropics/claude-code#18302
2026-04-13 08:26:01 +00:00
majdyz
ffa74177d0 fix: add ::timestamptz casts to raw SQL datetime comparisons in _build_raw_where
The raw SQL WHERE clause builder was passing datetime parameters without
explicit type casts, causing PostgreSQL to fail with "operator does not
exist: timestamp without time zone >= text".
2026-04-13 08:23:43 +00:00
majdyz
b6b94a2244 Merge branch 'fix/sse-replay-deduplication' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 08:06:29 +00:00
majdyz
7cadce4c7b fix(copilot): deduplicate SSE-replayed messages by content fingerprint
When the SSE connection reconnects, resume_session_stream replays from
"0-0" and the replayed UIMessage objects get new IDs from useChat,
bypassing the adjacent-only content dedup. Switch deduplicateMessages
to track all seen role+context+content fingerprints globally, scoped
by the preceding user message to avoid false positives when the
assistant legitimately gives identical answers to different prompts.
2026-04-13 08:04:04 +00:00
majdyz
00a20bdfe6 fix(copilot): deduplicate SSE-replayed messages by content fingerprint
When the SSE connection reconnects, resume_session_stream replays from
"0-0" and the replayed UIMessage objects get new IDs from useChat,
bypassing the adjacent-only content dedup. Switch deduplicateMessages
to track all seen role+context+content fingerprints globally, scoped
by the preceding user message to avoid false positives when the
assistant legitimately gives identical answers to different prompts.
2026-04-13 08:03:51 +00:00
majdyz
e0ddb7d4d4 Merge branch 'feat/enhanced-cost-dashboard' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs
# Conflicts:
#	autogpt_platform/backend/backend/data/platform_cost_test.py
2026-04-13 08:03:15 +00:00
majdyz
d8d0f752b5 Merge branch 'feat/builder-chat-panel' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs
# Conflicts:
#	autogpt_platform/backend/backend/data/platform_cost_test.py
2026-04-13 08:02:58 +00:00
majdyz
c64d5a9c92 Merge branch 'perf/copilot-prompt-caching' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 08:02:37 +00:00
majdyz
f8bca6f4bc Merge commit '2cf737dc0508a7753d067ed8425cfc0ef657b29f' into preview/all-prs
# Conflicts:
#	autogpt_platform/backend/backend/copilot/config.py
2026-04-13 08:02:31 +00:00
majdyz
6c21e58d31 Merge branch 'fix/orchestrator-per-iteration-cost' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 08:01:50 +00:00
majdyz
895c9a0d29 Merge branch 'feat/copilot-pending-messages' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 08:01:45 +00:00
majdyz
84e877e36d Merge branch 'fix/schedule-agent-cred-setup-ux' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 08:01:40 +00:00
majdyz
a504ad6e1e Merge branch 'chore/sdk-dev-preview-0.1.58-with-proxy' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs 2026-04-13 08:01:33 +00:00
majdyz
ca0c95b593 fix(frontend): add SUBSCRIPTION to CreditTransactionType enum in openapi.json
Syncs the OpenAPI spec with the Prisma schema which already includes the
SUBSCRIPTION enum value in CreditTransactionType.
2026-04-13 07:13:21 +00:00
majdyz
cbf309c9e4 Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into feat/copilot-pending-messages 2026-04-13 07:12:49 +00:00
majdyz
6ccb44e0d5 fix(copilot): add 404/429 to route decorator, reformat routes.py, regenerate openapi.json
Add responses={404, 429} to the pending endpoint's @router.post decorator
so FastAPI auto-generates them in the OpenAPI spec. Previously these were
only manually added to openapi.json and the CI schema-check (export +
diff) stripped them. Also apply black formatting to the long warning
line that was failing the backend lint check.
2026-04-13 07:04:07 +00:00
majdyz
e558c60104 fix(orchestrator): don't propagate non-billing charge errors as tool failures
Non-IBE exceptions from charge_node_usage (e.g. DB timeout) were
re-raised and caught by the outer generic handler, incorrectly marking
a successful tool execution as failed. This could cause the LLM to
retry side-effectful operations. Now logs the error and continues to
the success path since the tool itself completed successfully.
2026-04-13 07:02:10 +00:00
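The control flow described above can be sketched as follows, assuming IBE stands for an insufficient-balance error. The exception class here is a local stand-in and `execute_and_charge` is a hypothetical helper, not the orchestrator's actual method.

```python
import logging

logger = logging.getLogger(__name__)


class InsufficientBalanceError(Exception):
    """Local stand-in for the billing error the commit abbreviates as IBE."""


def execute_and_charge(run_tool, charge):
    """Run the tool, then charge; only billing errors may fail the call."""
    result = run_tool()
    try:
        charge()
    except InsufficientBalanceError:
        raise  # a genuine billing failure still propagates
    except Exception:
        # e.g. a DB timeout: the tool already succeeded, so log and continue
        logger.exception("charging failed after a successful tool run")
    return result
```

Returning the result on a non-billing failure avoids telling the LLM that a side-effectful tool failed when it did not.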
majdyz
fbad856538 fix(backend/copilot): relax schedule race test assertion for setup_test_data fixture
The setup_test_data fixture creates a graph with credentials already
embedded in node defaults. The DB-stored credential schema may not
surface these as "missing" in build_missing_credentials_from_graph,
so assert the key exists rather than asserting non-empty count.
2026-04-13 06:59:18 +00:00
majdyz
3ebfa3d68b fix(backend/copilot): address round-6 review — DRY validation handler, improve tests
- Extract duplicated GraphValidationError handler from _run_agent and
  _schedule_agent into _handle_graph_validation_race helper method
- Use generator expressions instead of list comprehension for
  short-circuit evaluation in _build_setup_requirements_from_validation_error
- Improve mixed-error fallback message to be more user-friendly
- Add test for empty node_errors={} edge case
- Pin expected credential count in firecrawl fixture tests
- Add missing_credentials assertion to schedule race E2E test
- Add test for extras present with node_errors=None in service_test
2026-04-13 06:45:28 +00:00
majdyz
5ff46ff207 fix(backend): address review feedback on orchestrator billing
- Extract post-execution billing into _handle_post_execution_billing()
- Deduplicate IBE notification into _try_send_insufficient_funds_notif()
- Combine _charge_usage + _handle_low_balance into single thread dispatch
- Sanitize error messages to LLM (no internal details leaked)
- Default _is_error to True (fail-closed) for tool responses
- Add IBE propagation contract to OrchestratorBlock class docstring
- Reduce per-site IBE comments to one-liners referencing class docstring
- Fix _resolve_block_cost return type annotation (Block | None)
- Move test imports to module level, fix test_default_block_returns_zero
- Add tests for non-IBE billing failure and _charge_usage(count=0)
- Fix Black formatting (CI lint blocker)
2026-04-13 06:44:20 +00:00
majdyz
2cf737dc05 fix(backend): address review comments on cross-user prompt caching PR
- Add TODO(#12747) to _SystemPromptPreset for cleanup tracking
- Update docstring to note SDK version and migration path
- Add debug logging in _build_system_prompt_value for observability
- Document empty-string edge case in docstring
- Trim redundant block comment at call site to single line
- Add test for empty-string system_prompt with cache enabled
- Add test for CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE=false env var
2026-04-13 06:43:57 +00:00
majdyz
040637dd68 fix: force cost_usd for percentile/histogram queries, fix test + prettier
- Backend: always pass tracking_type=None to _build_raw_where for
  percentile and histogram queries so they compute stats on cost_usd
  rows regardless of the caller's tracking_type filter.
- Frontend test: use getAllByText for "5" which appears in both the
  Active Users card and the $1-2 bucket count.
- Frontend: fix prettier formatting in PlatformCostContent.tsx.
2026-04-13 06:36:59 +00:00
majdyz
90d8ae0ae2 fix(copilot): map non-E2B file tools in permissions and fix lint formatting
In non-E2B mode, to_sdk_names() failed to map whitelisted SDK built-in
file tool names (Write, Edit, Read) to their MCP-prefixed equivalents
(mcp__copilot__Write, etc.), causing them to be incorrectly filtered out
when users configured tool whitelists.

Add _SDK_TO_MCP mapping for non-E2B mode that maps Read->read_file,
Write->Write, Edit->Edit. Add test coverage for this case.

Also fix black formatting in permissions_test.py that was causing CI lint
failure.
2026-04-13 06:34:55 +00:00
majdyz
967f0c97c4 fix(copilot): fix black formatting for single-line ValueError raise 2026-04-13 06:29:25 +00:00
majdyz
7dc4319125 fix: correct group_by count in test_passes_filters_to_queries
The 6th group_by (total agg no-tracking-type) only runs when
tracking_type is set. This test doesn't pass tracking_type, so the
expected count is 5, not 6.
2026-04-13 05:28:12 +00:00
majdyz
a8cfe27f6b fix: use real temp files in CLI path env var tests
The path validator rejects non-existent paths, so tests must create
real executable temp files via tmp_path instead of hardcoded paths.
2026-04-13 05:28:08 +00:00
majdyz
4cc8ef4409 fix(platform-cost): address PR review — deduplicate filter logic, skip redundant query, improve frontend
Backend:
- Extract _build_raw_where() helper so raw SQL and Prisma WHERE share
  filter logic (review item #4 — duplicated filter logic)
- Skip redundant total_agg_no_tracking_type_groups query when
  tracking_type is None since it duplicates total_agg_groups (item #3)
- Convert CostBucket from TypedDict to BaseModel for consistency (nit #1)
- Replace fragile 8-way positional tuple unpack with indexed list access

Frontend:
- Make 12 SummaryCards data-driven via a cards config array (item #5)
- Use friendlier percentile labels: Typical/Upper/High/Peak Cost (P50/P75/P95/P99)
- Update test fixtures with all new dashboard fields (item #1)
- Add test assertions for new summary card labels, cost buckets, token
  values, and user table columns
2026-04-13 05:16:55 +00:00
majdyz
359b7f1b81 fix(copilot): address PR reviewer feedback on CLI path validation and defaults
- Reject non-existent and non-file CLI paths at config validation time
  instead of letting them fail with opaque OS errors at runtime
- Add negative test coverage for CLI path validator (non-existent,
  non-executable, directory paths)
- Document breaking default changes (max_turns 1000->50, max_budget
  $100->$5) in field descriptions with env var override instructions
- Narrow broad `except Exception` to `except (ImportError, AttributeError)`
  in cli_openrouter_compat_test.py
2026-04-13 05:13:56 +00:00
Zamil Majdy
a3b0cea942 fix(frontend/builder): route text parts through MessagePartRenderer
Text parts in assistant messages were being rendered as plain <span>
elements, bypassing MessagePartRenderer's case "text" handler and
parseSpecialMarkers(). This broke styled error/system messages
([ERROR:], [RETRYABLE_ERROR:], [SYSTEM:] markers) and markdown
rendering in the builder chat panel.

Route all assistant message parts (text and tool) through
MessagePartRenderer so parseSpecialMarkers() runs on text content.
2026-04-13 04:42:18 +00:00
majdyz
ae1600a99d fix(copilot): rename SDK read_tool_result tool and fix path leak in error message
- Rename `_READ_TOOL_NAME` from `"Read"` to `"read_tool_result"` so the LLM
  can distinguish it from `read_file` (working-directory tool).  The new name
  plus an updated description make its narrow scope (tool-results/ paths and
  workspace:// URIs) unambiguous.
- Fix path leak in `_read_file_handler`: use `os.path.basename(file_path)` in
  the "Path not allowed" error, consistent with write/edit handlers.
- Update `permissions.py` comment and all `permissions_test.py` assertions to
  use the new `mcp__copilot__read_tool_result` name.
2026-04-13 04:27:17 +00:00
majdyz
45f96d5769 fix(copilot): wrap baseline turn-start drain in try/except; add 404/429 to OpenAPI spec
Baseline turn-start drain_pending_messages was unprotected — a transient
Redis error would propagate up and kill the entire turn stream, unlike the
already-protected mid-loop and SDK paths. Wrap with try/except + fallback
to [] so a Redis hiccup degrades gracefully.

Also adds 404 (session not found) and 429 (rate-limit exceeded) response
codes to the pending endpoint's OpenAPI spec so TypeScript clients can
handle these error paths correctly.
2026-04-13 04:24:29 +00:00
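The try/except plus fallback to [] can be sketched as a small wrapper. The name `drain_or_empty` and the zero-argument `drain` callable are assumptions; the real drain_pending_messages signature is not shown in this log.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def drain_or_empty(drain):
    """Await the drain coroutine function; on any error, log and return []."""
    try:
        return await drain()
    except Exception:
        logger.warning("pending-message drain failed; continuing without them", exc_info=True)
        return []
```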
majdyz
5dbbdf9b27 fix(copilot): address round-6 review nits
- Remove redundant inner `ChatConfig` import in `_prewarm_cli` — it was
  already imported at module scope on line 16 (style guide: inner imports
  only for heavy optional deps)
- Correct stale comment in `sdk_compat_test.py`: 2.1.63/2.1.70 pre-date
  the context-management regression and are OpenRouter-safe without any
  env var; only 2.1.97+ requires CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1
- Update `_assert_no_forbidden_patterns` error message in
  `cli_openrouter_compat_test.py`: remove the stale "above 0.1.45" ceiling
  (we've already upgraded to 0.1.58) and point at the correct remediation
  steps (add to _KNOWN_GOOD_BUNDLED_CLI_VERSIONS after bisect verification)
- Plug test coverage gap in `env_test.py`: add
  `CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS == "1"` assertions to three
  OpenRouter test methods that were missing it
  (test_strips_trailing_v1, test_strips_trailing_v1_and_slash,
  test_no_v1_suffix_left_alone) — guards against the env var being
  accidentally dropped from a code path that the main test didn't exercise
2026-04-13 04:23:54 +00:00
majdyz
e901b64bed fix(test): fix _handle_low_balance mock signature to accept positional args
The gated_processor fixture's fake_low_balance mock used **kwargs, but
production code calls _handle_low_balance with positional args via
asyncio.to_thread. This caused a silent TypeError caught by the broad
except handler, making the handle_low_balance assertion fail (0 calls
instead of 1). Updated mock to match the actual method signature.
2026-04-13 04:22:03 +00:00
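The failure mode is reproducible in isolation: asyncio.to_thread forwards positional arguments, and a fake that accepts only **kwargs raises TypeError when called that way. The function names and argument values below are illustrative, not the fixture's real ones.

```python
import asyncio


def fake_kwargs_only(**kwargs):
    """Mock shape from before the fix: rejects positional arguments."""
    return "called"


def fake_accepting_positional(*args, **kwargs):
    """Mock shape after the fix: accepts the positional call."""
    return "called"


async def call_like_production(func):
    """Mimic a positional call dispatched through asyncio.to_thread."""
    return await asyncio.to_thread(func, "user-1", 42)
```

In the scenario the commit describes, a broad except handler swallowed this TypeError, which is why the symptom was a call count of 0 rather than a visible exception.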
majdyz
64c3ef45df chore: apply Prettier formatting to BuilderChatPanel files
Three files were flagged by the CI lint/format check — apply prettier
--write to bring them into compliance.
2026-04-13 04:15:37 +00:00
majdyz
77ed619613 fix(frontend/builder): add flowID to tool-call effect deps for correct navigation guard 2026-04-13 04:09:05 +00:00
majdyz
626fe17aac fix(orchestrator): resolve None future on swallowed errors; add missing tests
- Move tool_node_stats None guard before node_exec_future.set_result so
  that when on_node_execution returns None (swallowed by @async_error_logged),
  the future carries set_exception(RuntimeError) rather than set_result(None),
  giving the tracking system an accurate error state
- Remove redundant `tool_node_stats is not None` check that was dead code
  after the early-return guard was added
- Add explanatory comment in _charge_extra_iterations_sync docstring explaining
  why the block lookup is intentionally repeated rather than cached from
  _charge_usage (two separate thread-pool workers, no shared mutable state)
- Add assertion to test_on_node_execution_charges_extra_iterations_when_gate_passes
  verifying _handle_low_balance is called when extra_cost > 0
- Add test_on_node_execution_failed_ibe_sends_notification covering the
  FAILED + InsufficientBalanceError path in on_node_execution (lines 822-836)
  that was previously untested
2026-04-13 04:03:08 +00:00
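The reordered guard can be sketched with a standard-library future. The use of concurrent.futures.Future and the helper name `resolve_future` are assumptions; the log does not say which future type node_exec_future is.

```python
from concurrent.futures import Future


def resolve_future(future: Future, stats) -> None:
    """Resolve with an error when stats is None, otherwise with the stats."""
    if stats is None:
        # A swallowed error upstream surfaces here as an explicit failure.
        future.set_exception(RuntimeError("node execution returned no stats"))
        return
    future.set_result(stats)
```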
majdyz
3b7e678b97 fix(frontend/builder): address round-5 review comments on BuilderChatPanel
- Add type="button" and focus-visible ring to Stop/Send buttons in PanelInput
- Add type="button" to Retry button in MessageList and Apply button in ActionList
- Fix MessageList to render plain text directly and only pass dynamic-tool parts
  to MessagePartRenderer (text parts were being misrouted through a tool renderer)
- Replace clearGraphSessionCacheForTesting export with _graphSessionCache for
  tests — avoids leaking test scaffolding into the production bundle
- Add toast notification in undo restore when target node was deleted between
  apply and undo (prevents silent no-op)
- Fix misleading test: remove red-herring mockNodes.push from 'no auto-send' test
  since the guard is isGraphLoaded===false, not the node array
- Add truncation-path coverage to helpers.test.ts (MAX_NODES/MAX_EDGES branches)
- Add deleted-node undo test to actionApplicators.test.ts
2026-04-13 04:01:42 +00:00
majdyz
c51471a9df fix(platform-cost): replace non-null assertion with nullish coalesce, add token total test assertions, add bucket skeleton
- UserTable: replace `cost_bearing_request_count!` non-null assertion with
  `?? 1` nullish coalesce — eliminates the TypeScript anti-pattern and
  guards against a theoretical divide-by-zero if the guard is refactored
- platform_cost_test: add assertions for `total_input_tokens` and
  `total_output_tokens` in test_returns_dashboard_with_data to cover the
  "Total Tokens" summary card computation path
- PlatformCostContent: add a h-32 skeleton placeholder for the cost-bucket
  histogram section so the loading state reflects the loaded layout more
  closely and reduces CLS
2026-04-13 04:00:25 +00:00
majdyz
10980f3799 fix(copilot): wrap SDK turn-start drain in try/except, deduplicate format calls, elevate context length constants
- sdk/service.py: wrap drain_pending_messages at turn start in try/except;
  a transient Redis error no longer kills the entire turn (baseline mid-loop
  drain was already protected, SDK was missed in round 5)
- baseline/service.py: pre-compute format_pending_as_user_message content
  once per drained message and reuse it for both session.messages and
  transcript_builder — eliminates the redundant second call per message
- routes.py: move _URL_LIMIT/_CONTENT_LIMIT out of the validator body into
  module-level _CONTEXT_URL_MAX_LENGTH/_CONTEXT_CONTENT_MAX_LENGTH so the
  contract limits are visible to tooling without reading the implementation
2026-04-13 03:57:54 +00:00
majdyz
4ea5cd5f7f fix(backend/copilot): address round-5 review comments
- Add "do NOT redirect to the Builder for credential setup" guardrail to
  run_agent description, making it symmetric with create_agent/edit_agent
- Scrub error message text from race-path warning logs; log only node IDs
  and field names to avoid leaking credential IDs/provider details
- Add code comment explaining the None-vs-filtering trade-off in
  _build_setup_requirements_from_validation_error
- Add E2E tests for structural-error fallback on both run and schedule
  paths (verifies ErrorResponse returned, not setup_requirements card)
2026-04-13 03:54:31 +00:00
majdyz
e0d5047974 test(copilot): plug two test coverage gaps found in round-5 review
- env_test: add missing CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS assertion
  to test_no_anthropic_key_overrides_when_openrouter_flag_true_but_no_key
  (the other three build_sdk_env test cases already assert it; this case
  was the only one that didn't, leaving the env-var injection unverified
  for the openrouter_active=False / no-key path)

- sdk_compat_test: add test_sdk_exposes_max_thinking_tokens_option
  parallel to the existing test_sdk_exposes_cli_path_option — guards
  against a future SDK rename/removal of max_thinking_tokens silently
  disabling the Opus thinking-token cost cap
2026-04-13 03:53:41 +00:00
majdyz
ac0d939dd2 fix(copilot): address round-5 review — path leaks, Read partial truncation, concurrent edit test
- Use `file_path` (caller-supplied) instead of `resolved` in Write/Edit
  success messages to avoid leaking `/tmp/copilot-<session>/...` to the LLM
- Add partial-truncation guard to `_read_file_handler` (MCP `Read` tool):
  when `offset`/`limit` are present but `file_path` is missing, return a
  specific truncation message instead of the generic `file_path is required`
- Add `TestConcurrentEditLocking` test that uses `asyncio.gather` to verify
  two parallel Edit calls on the same file are serialised by `_edit_locks`
- Add `autouse` fixture `_clear_edit_locks` to prevent module-level dict
  from bleeding between test runs
2026-04-13 03:53:17 +00:00
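A minimal sketch of the serialisation that the concurrent-edit test checks, assuming one asyncio.Lock per path. The in-memory `files` dict, the explicit `locks` argument, and `edit_file` are illustrative stand-ins for the real handler and its module-level _edit_locks.

```python
import asyncio


async def edit_file(locks: dict, files: dict, path: str, old: str, new: str) -> None:
    """Read-modify-write one file while holding that file's lock."""
    lock = locks.setdefault(path, asyncio.Lock())
    async with lock:
        content = files[path]
        await asyncio.sleep(0)  # yield; without the lock another edit would interleave here
        files[path] = content.replace(old, new, 1)


async def demo() -> str:
    locks: dict = {}
    files = {"a.txt": "x y"}
    await asyncio.gather(
        edit_file(locks, files, "a.txt", "x", "1"),
        edit_file(locks, files, "a.txt", "y", "2"),
    )
    return files["a.txt"]
```

Keeping the lock dict local to each run, as here, sidesteps the cross-test leakage that the autouse fixture in the commit addresses for the module-level dict.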
majdyz
929718768a fix(platform-cost): avg stats use unfiltered agg to stay nonzero when tracking_type filtered
When a caller filters the dashboard by tracking_type='tokens', total_agg_groups
only contains tokens rows so cost_bearing_requests=0 and avg_cost_microdollars_per_request
silently returned 0.0. Symmetrically, filtering by cost_usd gave zero token averages.

Add a parallel total_agg_no_tracking_type_groups query (using where_no_tracking_type,
mirroring the fix already applied to by_user_tracking_groups) and derive avg_cost_total,
avg_input_total, avg_output_total, cost_bearing_requests, and token_bearing_requests
from that unfiltered aggregate. The displayed grand totals (total_cost, total_requests,
total_input_tokens) remain scoped to the active filter.

Also adds test_global_avg_cost_nonzero_when_filtering_by_tokens to cover this case.
2026-04-13 03:42:18 +00:00
majdyz
34832ca70c test(backend): compat-test the exact preset dict sent to ClaudeAgentOptions
The existing compat test for SystemPromptPreset omitted exclude_dynamic_sections,
diverging from the actual dict _build_system_prompt_value produces. The new test
calls the production helper directly and passes its output through ClaudeAgentOptions,
so any SDK version that rejects the extra key is caught at test time.
2026-04-13 03:34:19 +00:00
majdyz
4cd955c758 test: add tests for cost_bearing_request_count fix and tracking_type filter isolation
Two new tests in TestGetPlatformCostDashboard:
1. test_cost_bearing_request_count_nonzero_when_filtering_by_tokens: verifies
   that cost_bearing_request_count per user is correct even when the main
   tracking_type filter is 'tokens' (regression guard for the bug where
   by_user_tracking_groups used the filtered where-clause).
2. test_user_tracking_groups_excludes_tracking_type_filter: verifies that the
   3rd group_by call (by_user_tracking_groups) does NOT receive a trackingType
   constraint while the 1st call (by_provider) does.
2026-04-13 03:14:27 +00:00
majdyz
b7f1173cc4 fix: cost_bearing_request_count always 0 when filtering by non-cost_usd tracking type
When the caller filters the main view by e.g. tracking_type=tokens, the
by_user_tracking_groups query was also filtered, excluding all cost_usd rows
and making cost_bearing_request_count zero for every user. Use a separate
where_no_tracking_type filter (omitting tracking_type) for this sub-query so
cost_usd rows are always present for correct per-user avg cost denominators.
2026-04-13 02:58:08 +00:00
majdyz
88994a62ab fix(openapi): restore original formatting, insert CostBucket in alphabetical position 2026-04-13 02:44:44 +00:00
majdyz
bd7db8ff03 fix(openapi): move CostBucket schema to alphabetical position, fix cost_buckets field order 2026-04-13 02:37:35 +00:00
majdyz
8babdfe12f fix(frontend-test): use getAllByText for Known Cost which appears in both card and table header
2026-04-13 02:26:18 +00:00
majdyz
91882be590 fix(type-check): construct CostBucket TypedDict instances to satisfy Pyright
Pyright rejects `list[dict[str, Unknown]]` being passed as `list[CostBucket]`
because list is invariant. Constructing CostBucket instances explicitly
satisfies the type checker across Python 3.11/3.12/3.13.
2026-04-13 02:18:56 +00:00
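A small sketch of the pattern, assuming hypothetical field names `bucket` and `count` (the real CostBucket fields are not listed in this log). Constructing each element as a CostBucket gives the list the declared element type instead of a general dict type.

```python
from typing import TypedDict


class CostBucket(TypedDict):
    bucket: str  # hypothetical field name
    count: int  # hypothetical field name


def build_buckets(rows: list[tuple[str, int]]) -> list[CostBucket]:
    """Construct typed buckets explicitly so the result is a list[CostBucket]."""
    return [CostBucket(bucket=label, count=count) for label, count in rows]
```

At runtime a TypedDict instance is a plain dict, so this changes only what the type checker sees.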
majdyz
639b69b9d9 fix(api-types): add CostBucket as named schema; fix generated TS model path
- Add CostBucket to openapi.json components/schemas so orval generates
  a costBucket.ts file instead of an inline anonymous type
- Use $ref in cost_buckets items array for proper orval code generation
- Create costBucket.ts generated model; update platformCostDashboard.ts
  to import from it instead of defining CostBucket inline
- Update PlatformCostContent.tsx import to use costBucket directly
2026-04-13 02:17:57 +00:00
majdyz
33ff46e96a style(frontend): apply prettier formatting to PlatformCostContent and openapi.json 2026-04-13 02:15:58 +00:00
majdyz
fbb93e2ddf fix(frontend-test): update renders empty dashboard assertion for 7 zero-cost cards
The PR added 5 new cost summary cards (Avg Cost, P50, P75, P95, P99)
that also display $0.0000 when empty, so the test assertion needed to
change from 2 to 7 matching elements.
2026-04-13 02:15:28 +00:00
majdyz
187b4596e0 fix(platform-cost): fix per-user avg cost denominator, NULL bucket, tracking_type filter gap
- Add `cost_bearing_request_count` to `UserCostSummary` via a new
  group-by-(userId,trackingType) query; `UserTable` now divides by
  this count instead of the mixed `request_count`, eliminating
  denominator dilution for users with both tokens and cost_usd rows
- Guard histogram CASE against NULL costMicrodollars (NULL < N evaluates
  to unknown, so the row falls through to ELSE '$10+'); add
  `AND "costMicrodollars" IS NOT NULL` to the histogram WHERE so NULL
  rows are excluded instead of bucketed
- Respect the `tracking_type` dashboard filter in raw SQL percentile
  and bucket queries; previously the filter was hardcoded to 'cost_usd'
  even when the caller passed tracking_type='tokens', making those
  queries return inconsistent data relative to the ORM queries
- Add p75 and p99 assertions to test_returns_dashboard_with_data
- Update openapi.json and generated TS model for new field
2026-04-13 02:10:40 +00:00
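The NULL-bucket problem fixed in 187b4596e0 can be reproduced with a small stand-in. The sketch below uses SQLite from the Python standard library rather than PostgreSQL, on the assumption that both apply the same three-valued logic to `NULL < N`. The table name, column name, and two-bucket CASE are simplified placeholders, not the dashboard's real query.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE logs (cost INTEGER)")
conn.executemany(
    "INSERT INTO logs VALUES (?)",
    [(500,), (None,), (20_000_000,)],
)

# Simplified two-bucket histogram; the real query has more thresholds.
CASE = "CASE WHEN cost < 1000000 THEN '<$1' ELSE '$10+' END"

# Without the guard, NULL < N is not true, so the NULL row lands in ELSE.
unguarded = dict(
    conn.execute(f"SELECT {CASE} AS bucket, COUNT(*) FROM logs GROUP BY bucket")
)

# With the guard, NULL rows are excluded before bucketing.
guarded = dict(
    conn.execute(
        f"SELECT {CASE} AS bucket, COUNT(*) FROM logs "
        "WHERE cost IS NOT NULL GROUP BY bucket"
    )
)
```

The unguarded query counts the NULL row in the top bucket, which inflates the most expensive tier.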
majdyz
f6f70e1c15 fix(backend): add SystemPromptPreset compat test, move inline import to top level
- sdk_compat_test.py: add test_agent_options_accepts_system_prompt_preset_dict
  to guard against SDK upgrades breaking the dict-variant system_prompt path
  introduced by cross-user prompt caching
- service_test.py: move `from backend.copilot import config as cfg_mod` to
  top-level imports (AGENTS.md: no local/inner imports)
2026-04-13 02:07:46 +00:00
majdyz
ac973396a2 fix(backend): use typing_extensions.TypedDict for Python < 3.12 compat 2026-04-13 01:46:12 +00:00
majdyz
8c50cb5fbc fix(platform-cost): correct token avg denominator; add CostBucket type; update generated TS
- avg_input/output_tokens_per_request now divide by token_bearing_requests
  (rows where trackingType='tokens') instead of cost_bearing_requests
  (rows where trackingType='cost_usd'). These are different DB rows: LLM
  calls log tokens under 'tokens' type, not 'cost_usd', so the old
  denominator was wrong and produced inflated averages.
- Add CostBucket TypedDict to platform_cost.py; replace list[dict] with
  list[CostBucket] for type safety on cost_buckets field.
- Update openapi.json cost_buckets item schema with explicit bucket/count
  properties so orval generates a typed interface instead of object.
- Import CostBucket in PlatformCostContent.tsx and use it instead of
  inline anonymous type on the .map() callback.
- Add test assertions for avg_input/output_tokens_per_request and
  avg_cost_microdollars_per_request in test_returns_dashboard_with_data
  to lock in the correct denominator behaviour.
2026-04-13 01:40:02 +00:00
majdyz
b7bec5d352 fix(backend): align _SystemPromptPreset with SDK shape, drop unused monkeypatch
- Make append and exclude_dynamic_sections NotRequired to match the SDK's
  SystemPromptPreset (append is NotRequired[str] in SDK; exclude_dynamic_sections
  is absent in 0.1.45 and will be optional once #12747 bumps to >=0.1.58)
- Remove redundant monkeypatch fixture from test_default_config_is_enabled
  (_clean_config_env already owns the monkeypatch instance)
2026-04-13 01:39:42 +00:00
majdyz
0be1d7ddbc fix(platform-cost): apply all dashboard filters to raw SQL percentile/bucket queries
Raw SQL queries for percentile and histogram bucket distributions were only
filtering by createdAt >= start, ignoring end, provider, user_id, model, and
block_name. This caused percentile/bucket stats to silently include data
outside the selected filter scope.

Also: add P75 summary card to frontend, fix skeleton count (8->12) to match
rendered card count, fix "Avg Cost / Request" subtitle to accurately say
"cost-bearing requests", and add test assertion that raw queries receive active
filter params.
2026-04-13 01:23:26 +00:00
majdyz
cd8079dba2 test: add CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE to _CONFIG_ENV_VARS
test_default_config_is_enabled uses _clean_config_env to ensure env
vars don't pollute the ChatConfig constructor test.  The new
claude_agent_cross_user_prompt_cache field reads from
CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE, but that var was missing from
the list — leaving the test non-deterministic if that env var is set in CI.
2026-04-13 01:23:08 +00:00
majdyz
df989853c1 Fix event loop error in TestGetPlatformCostDashboard tests
The raw SQL queries (query_raw_with_schema) added for percentile/bucket
computation were not mocked in unit tests, causing the real Prisma client
to be invoked. Its asyncio primitives were bound to a different event
loop than the test's, producing RuntimeError.

Mock query_raw_with_schema in all four TestGetPlatformCostDashboard tests
and add assertions for the new percentile/bucket dashboard fields.
2026-04-13 01:05:20 +00:00
majdyz
0ab7c9852c fix: remove accidentally committed worktree, add to gitignore 2026-04-13 00:54:11 +00:00
majdyz
fa6cc99a8a fix(backend): format service.py and test files 2026-04-13 00:54:01 +00:00
majdyz
348cdac328 feat(platform-cost): add cost percentile and distribution stats to dashboard
Add p50/p75/p95/p99 cost percentiles using PostgreSQL percentile_cont()
and histogram bucket distribution to PlatformCostDashboard. Display
percentile summary cards and cost bucket grid in the frontend.
2026-04-13 00:52:20 +00:00
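For readers unfamiliar with continuous percentiles, the following sketch approximates what `percentile_cont()` computes in 348cdac328: it sorts the values and interpolates linearly between the two neighbouring items. It is a pure-Python illustration with made-up cost values, not the SQL used by the dashboard.

```python
def percentile_cont(values, fraction):
    """Continuous percentile with linear interpolation between neighbours."""
    if not values:
        return None
    ordered = sorted(values)
    # Fractional zero-based position of the requested percentile.
    position = fraction * (len(ordered) - 1)
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    weight = position - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * weight


# Hypothetical per-request costs in microdollars.
costs = [100, 200, 300, 400]
p50 = percentile_cont(costs, 0.50)
p75 = percentile_cont(costs, 0.75)
```

Because the result is interpolated, a percentile can be a value that never occurs in the data, which is why the dashboard cards may show costs no single request incurred.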
majdyz
54f507b54b fix(backend): address PR review — extract testable helper, add TypedDict, rename config field
- Extract _build_system_prompt_value() helper so tests exercise
  production code instead of reconstructing the dict locally.
- Add _SystemPromptPreset TypedDict for proper type annotation
  (replaces str | dict[str, Any]).
- Rename claude_agent_exclude_dynamic_sections →
  claude_agent_cross_user_prompt_cache for clarity.
2026-04-13 00:48:37 +00:00
majdyz
c4e48b5c71 perf(backend): enable cross-user prompt caching via SystemPromptPreset
Use SystemPromptPreset with exclude_dynamic_sections=True in the SDK
path so the Claude Code default prompt serves as a cacheable prefix
shared across all users. Our custom prompt is appended after it, and
dynamic sections (working dir, git status, auto-memory) are excluded
from the prefix -- giving cross-user cache hits that reduce input
token cost by ~90%.

Add claude_agent_exclude_dynamic_sections config field (default True)
to make this configurable, with fallback to raw string when disabled.
2026-04-13 00:39:30 +00:00
majdyz
497cc15a8b fix(backend): update guardrail tests for new defaults
Update test assertions to match new defaults:
- max_turns: 1000 → 50
- max_budget_usd: 100.0 → 5.0
- Add test for max_thinking_tokens default (8192)
2026-04-13 00:22:50 +00:00
majdyz
b044862dba perf(copilot): add thinking token cap and lower default budget/turns
Cost investigation showed 54% of spend is thinking tokens (~57K/turn
avg at $75/M for Opus). Add max_thinking_tokens config (default 8192)
to cap extended thinking output per LLM call.

Also lower defaults:
- max_budget_usd: $100 → $5 per turn
- max_turns: 1000 → 50 tool-use loops

These are configurable via env vars (CHAT_CLAUDE_AGENT_MAX_THINKING_TOKENS,
CHAT_CLAUDE_AGENT_MAX_BUDGET_USD, CHAT_CLAUDE_AGENT_MAX_TURNS).
2026-04-13 00:12:16 +00:00
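The defaults and environment overrides introduced in b044862dba can be sketched as follows. The env var names and default values come from the commit message; the `load_guardrails` helper and its dict return shape are assumptions for illustration and do not reflect how the project's config class is actually built.

```python
import os


def load_guardrails(env=os.environ):
    """Read the three guardrail settings, falling back to the new defaults."""
    return {
        "max_thinking_tokens": int(
            env.get("CHAT_CLAUDE_AGENT_MAX_THINKING_TOKENS", 8192)
        ),
        "max_budget_usd": float(env.get("CHAT_CLAUDE_AGENT_MAX_BUDGET_USD", 5.0)),
        "max_turns": int(env.get("CHAT_CLAUDE_AGENT_MAX_TURNS", 50)),
    }


defaults = load_guardrails({})
overridden = load_guardrails({"CHAT_CLAUDE_AGENT_MAX_TURNS": "10"})
```

Passing the environment as a parameter keeps the function testable without mutating `os.environ`.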
majdyz
cdd566c9d1 fix(backend): format platform_cost_test.py 2026-04-13 00:11:26 +00:00
majdyz
2a0ff06f9c fix(platform/admin): correct avg cost/token denominators, UI formatting fixes
- Filter avg cost and token denominators to cost-bearing requests only
  (trackingType == "cost_usd") to prevent dilution by non-cost rows
- Remove Math.round before formatMicrodollars in UserTable to preserve
  sub-dollar precision
- Use nullish coalescing (?? 0) instead of falsy check for avg_cost
- Update skeleton placeholder count from 4 to 8 to match actual cards
2026-04-13 00:02:59 +00:00
majdyz
2084e6e06e feat(platform/admin): enhance cost dashboard with token breakdown and per-request averages
Add deeper cost visibility to the admin platform cost dashboard:
- Show prompt vs completion tokens separately in provider table
- Add summary cards for avg cost/request, avg input/output tokens, total token split
- Add avg cost per request column to the per-user table
- Compute aggregate token totals and per-request averages in backend
2026-04-12 23:51:04 +00:00
majdyz
ed65756d58 fix(frontend): reset isCreatingSession in retrySession and cap parsedActions
Add setIsCreatingSession(false) at the start of retrySession so the
spinner is cleared when retrying during an in-flight session creation.

Add MAX_PARSED_ACTIONS=100 cap to trim the oldest entries from the
accumulated action list, preventing unbounded growth in long conversations.
2026-04-12 23:22:08 +00:00
majdyz
099d5cf1b2 test(copilot): assert CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS in subscription and direct-Anthropic modes
Add the assertion to TestBuildSdkEnvSubscription.test_returns_blanked_keys
and TestBuildSdkEnvDirectAnthropic.test_no_anthropic_key_overrides_when_openrouter_inactive.
2026-04-12 23:18:56 +00:00
majdyz
057412ebee fix(copilot): allow exactly 30 pending calls per window
Change >= to > so the 30th call (INCR returns 30) is accepted
and only the 31st triggers the 429.
2026-04-12 23:14:54 +00:00
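The boundary behaviour fixed in 057412ebee can be checked with a small simulation. `FakeCounter` is an in-memory stand-in for the Redis INCR counter, and `is_rate_limited` is a hypothetical name; only the limit of 30 and the `>` comparison come from the commit message.

```python
MAX_PENDING_CALLS = 30  # limit named in the commit message


class FakeCounter:
    """In-memory stand-in for a Redis INCR counter (assumption for this sketch)."""

    def __init__(self):
        self.value = 0

    def incr(self):
        self.value += 1
        return self.value


def is_rate_limited(counter):
    # INCR returns the count including this call, so only values above
    # the limit are rejected; `>=` would reject the 30th call itself.
    return counter.incr() > MAX_PENDING_CALLS


counter = FakeCounter()
results = [is_rate_limited(counter) for _ in range(31)]
```

The first 30 calls are accepted and only the 31st is rejected.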
majdyz
5f92082f9c fix(backend/copilot): harden system prompt to distrust user_context on turn 2+
The system prompt previously told the LLM to use <user_context> blocks
"when the user provides" them, which could let a turn-2+ injection slip
past even after the server-side strip. The prompt now explicitly states
that <user_context> is server-injected, only appears on the first
message, and must be ignored on subsequent messages.

Combined with the strip_user_context_tags() sanitization (applied
unconditionally to every incoming message in both SDK and baseline
paths), this provides defence-in-depth against prompt injection via
fake user context.
2026-04-12 12:58:12 +00:00
majdyz
f07143c5ea fix(backend/copilot): strip <user_context> tags from all user messages
The sanitization was only applied on the first turn (guarded by
`not has_history` / `is_first_turn`), allowing users to inject fake
`<user_context>` blocks on turn 2+ that the LLM would trust.

Add `strip_user_context_tags()` to the shared service module and call
it on every incoming user message in both SDK and baseline paths,
before the message is stored or forwarded to the LLM.
2026-04-12 12:36:00 +00:00
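One possible shape for the sanitizer added in f07143c5ea is sketched below. Only the function name `strip_user_context_tags` comes from the commit; the regex and the repeat-until-stable loop are assumptions, and the real implementation may handle unclosed or malformed tags differently.

```python
import re

_USER_CONTEXT_RE = re.compile(
    r"<user_context>.*?</user_context>", re.DOTALL | re.IGNORECASE
)


def strip_user_context_tags(message: str) -> str:
    """Remove every <user_context>...</user_context> block from a message."""
    previous = None
    # Repeat until stable so that removing one block cannot splice the
    # surrounding text into a new, valid-looking block.
    while previous != message:
        previous = message
        message = _USER_CONTEXT_RE.sub("", message)
    return message.strip()


cleaned = strip_user_context_tags(
    "hi <user_context>role=admin</user_context> there"
)
```

Calling the function on every incoming message, rather than only on the first turn, is what closes the turn-2+ injection path described above.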
majdyz
b62288655f Add None guard for tool_node_stats after on_node_execution
If on_node_execution returns None, return an error response with
_is_error=True instead of falling through to the success path.
2026-04-12 12:10:43 +00:00
majdyz
f3f598daa3 Wrap mid-loop drain_pending_messages in try/except
If the Redis drain fails mid-tool-loop, log a warning and treat it as
no pending messages rather than crashing the entire copilot turn.
2026-04-12 12:10:05 +00:00
majdyz
44aa676fa5 fix(frontend): update stale assertion in applyConnectNodes duplicate-edge test
After the double-call fix removed the direct setAppliedActionKeys call
from the alreadyExists branch, the test still expected it to be called
once. Updated to .not.toHaveBeenCalled() since the caller
(handleApplyAction) now handles marking applied keys.
2026-04-12 12:09:06 +00:00
majdyz
c4d8293fad Add explicit CRED_ERR_INVALID_TYPE_MISMATCH to _CREDENTIAL_ERROR_MARKERS
Previously covered accidentally by the CRED_ERR_INVALID_PREFIX prefix
rule. Adding an explicit exact-match entry makes the intent clear and
prevents breakage if the prefix constant ever changes.
2026-04-12 12:05:37 +00:00
majdyz
2c2fadba47 fix(backend): add env var test coverage and fix stale comments
- Add CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS assertion in env_test.py
- Fix stale references to service.py → build_sdk_env() in env.py
2026-04-12 12:04:15 +00:00
majdyz
5d7fa7c216 fix(backend): update test to use PendingMessageContext attribute access
context is now a PendingMessageContext object, not a dict — use
.url attribute instead of ["url"] subscript.
2026-04-12 11:42:06 +00:00
Zamil Majdy
4ccfec589b fix(backend): use mutating annotation for E2B write/edit tools and remove phantom tool names
- Apply _MUTATING_ANNOTATION (readOnlyHint=False) to E2B write_file and
  edit_file tools to prevent unsafe parallel dispatch of file-mutating
  operations in the sandbox
- Remove WRITE_TOOL_NAME and EDIT_TOOL_NAME from get_copilot_tool_names
  E2B branch since those unified tools are not registered in E2B mode
  (E2B uses write_file/edit_file from E2B_FILE_TOOLS instead)
2026-04-12 11:23:08 +00:00
majdyz
7b783aa03b fix(backend): use PendingMessageContext type in QueuePendingMessageRequest to prevent 500
Change context field from dict[str,str] to PendingMessageContext so
Pydantic validates (including extra="forbid") at request parse time,
returning a proper 422 instead of an unhandled ValidationError / 500
when the caller sends unexpected keys.
2026-04-12 11:21:23 +00:00
majdyz
2704e43d42 fix(backend): move CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS into build_sdk_env
Moves the env var injection from service.py into build_sdk_env() in
env.py so all callers (including orchestrator.py) get it automatically.
Also changes xfail(strict=False) to strict=True so CI fails if the
upstream fix lands and we can remove the workaround.
2026-04-12 11:17:56 +00:00
majdyz
a35e9a2b4c fix(backend): fix CI failures from proxy removal
- cli_path validation: only check os.access(X_OK) when path exists
  (tests use non-existent paths to verify env var resolution)
- Mark bare CLI test as xfail since CLI 2.1.97 sends the beta header
  without the env var — the env var test is the real regression guard
2026-04-12 11:10:48 +00:00
majdyz
6bad358d78 fix(frontend/builder): address review comments on useBuilderChatPanel
- Clear graphSessionCache on user change to prevent session leaks across sign-outs
- Reset all per-session state in retrySession including skipNextParseRef/skipNextToolScanRef
- Trim whitespace in sendRawMessage before empty guard
- Remove duplicate setAppliedActionKeys call from alreadyExists branch in applyConnectNodes
2026-04-12 11:01:34 +00:00
majdyz
8e9bb083b2 refactor(backend): replace compat proxy with CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS env var 2026-04-12 10:53:25 +00:00
majdyz
5e8d3ba889 fix(backend/executor): harden orchestrator tool execution error handling
- Replace assert with proper if-guard + RuntimeError for node_exec_result
- Wrap on_node_execution in try/except to always resolve the Future via
  set_exception on error, preventing dangling unresolved Futures
- Add separate try/except around charge_node_usage so non-balance billing
  exceptions propagate instead of being silently swallowed by the outer
  except Exception handler
- Add _is_error flag to tool response dicts to replace fragile string
  matching on "Tool execution failed:" prefix
2026-04-12 10:17:09 +00:00
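The Future-handling change in 5e8d3ba889 follows a common asyncio pattern, sketched here under stated assumptions. `run_and_resolve` and `failing` are hypothetical names, and the orchestrator's real code involves more state than this.

```python
import asyncio


async def run_and_resolve(future, coro_fn):
    """Run coro_fn and always resolve `future`, with a result or the exception."""
    try:
        result = await coro_fn()
    except Exception as exc:
        # Without this branch the Future would stay pending forever and
        # any awaiter would hang.
        future.set_exception(exc)
    else:
        future.set_result(result)


async def failing():
    raise RuntimeError("node execution failed")


async def demo():
    future = asyncio.get_running_loop().create_future()
    await run_and_resolve(future, failing)
    return future.done(), type(future.exception()).__name__


outcome = asyncio.run(demo())
```

The key property is that the Future is done on every path, so a caller awaiting it sees the error instead of waiting indefinitely.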
majdyz
db9eb29138 fix(backend): address review findings for pending-message endpoint
- Fix off-by-one in rate limit: use >= instead of > for call count check
- Move track_user_message() after push_pending_message() so analytics
  only fires on successful push
- Add logger.warning in rate-limiter except-Exception catch instead of
  silent pass
- Use fullmatch instead of match for UUID regex validation
- Add extra="forbid" to PendingMessageContext to reject unexpected fields
2026-04-12 10:13:45 +00:00
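The difference between `match` and `fullmatch` mentioned in db9eb29138 is easy to demonstrate. The pattern below is an illustrative, unanchored UUID regex and the candidate string is made up; the project's actual pattern is not shown in this log.

```python
import re

# Illustrative unanchored pattern (assumption, not the project's regex).
UUID_RE = re.compile(
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
)

candidate = "123e4567-e89b-12d3-a456-426614174000/../etc"

# match() only anchors at the start, so trailing junk is accepted.
prefix_ok = UUID_RE.match(candidate) is not None

# fullmatch() requires the whole string to be a UUID.
whole_ok = UUID_RE.fullmatch(candidate) is not None
```

With `match`, a string that merely starts with a valid UUID passes validation; `fullmatch` rejects it.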
majdyz
e92ecbbb7c fix(backend): address review comments on SDK upgrade PR
- Make strip_forbidden_betas_from_body non-mutating (returns shallow
  copy instead of modifying caller's dict in-place)
- Add os.access(X_OK) validation for claude_agent_cli_path to reject
  non-executable paths at config load time
- Replace hardcoded /v1 path dedup with generic urlparse-based logic
  that handles any API version prefix in the target URL
2026-04-12 10:08:29 +00:00
majdyz
ff32fa2772 fix(backend): update test_read_builtin_blocked for workspace-scoped Read
Read is now workspace-scoped (allowed within sdk_cwd, denied outside).
Split the old test into two: test_read_within_workspace_allowed and
test_read_outside_workspace_blocked.
2026-04-12 10:06:38 +00:00
majdyz
d7d9b5ea91 fix(backend): address review comments on unified file tools PR
- Add _MUTATING_ANNOTATION (readOnlyHint=False) for Write/Edit tools to
  prevent unsafe parallel dispatch of file-mutating operations
- Fix non-atomic lock creation with setdefault instead of check-then-set
- Remove racy lock eviction after async with block
- Gate unified Write/Read/Edit behind not use_e2b to prevent duplicate
  tool registration (E2B already has write_file/read_file/edit_file)
- Remove unused use_e2b param from get_*_tool_handler functions
2026-04-12 09:45:36 +00:00
majdyz
ea4d936886 fix(backend): address review comments on credential setup PR
- Add credential-routing guardrail to create_agent and edit_agent tool
  descriptions so it's always visible to the LLM (not just in the guide)
- Change issubclass to identity check for GraphValidationError in
  _handle_call_method_response to avoid fragile subclass dispatch
- Add new_callable=AsyncMock for consistency in execution race test
- Move orjson and RemoteCallError imports to module level in service_test
- Add Literal type annotation to _CREDENTIAL_ERROR_MARKERS match mode
2026-04-12 09:40:31 +00:00
majdyz
7f782d4676 ci(backend): add test to validate CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS env var
Adds a new test that spawns the CLI with
CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 (without the compat proxy) and
checks whether the context-management-2025-06-27 beta header is
stripped. If this test passes in CI, the proxy can be removed entirely
in favour of the simpler env var approach.
2026-04-12 09:33:11 +00:00
majdyz
ab07e55635 fix(backend): allow Read tool for workspace-scoped paths (tool-results/tool-outputs)
The security hooks were blocking Read unconditionally because it was in
BLOCKED_TOOLS. However, the SDK needs Read to access tool-results/ and
tool-outputs/ directories for oversized result handling. Fix by adding
Read to WORKSPACE_SCOPED_TOOLS and checking workspace scope before the
blocked-tools list, so Read is allowed within the workspace but still
blocked for arbitrary paths.
2026-04-12 09:05:54 +00:00
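A workspace-scope check of the kind described in ab07e55635 might look like the sketch below. The function name `is_within_workspace` and the example paths are assumptions, the code assumes POSIX-style paths, and the project's actual hook logic is not shown in this log.

```python
import os


def is_within_workspace(path: str, workspace: str) -> bool:
    """Return True only if `path` resolves to a location under `workspace`."""
    root = os.path.realpath(workspace)
    # Joining first means relative paths are interpreted inside the
    # workspace; absolute paths replace the root and are judged as-is.
    target = os.path.realpath(os.path.join(root, path))
    return os.path.commonpath([root, target]) == root


allowed = is_within_workspace("tool-results/out.txt", "/tmp/ws")
escaped = is_within_workspace("../secrets", "/tmp/ws")
absolute = is_within_workspace("/etc/passwd", "/tmp/ws")
```

Resolving with `realpath` before comparing is what defeats `..` segments and symlinks that point outside the workspace.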
majdyz
c228b2c4c6 fix(copilot): fix pyright bytes/bytearray type error in e2b read 2026-04-12 08:42:21 +00:00
majdyz
d3a5bdb580 refactor(copilot): consolidate file tools into single e2b_file_tools.py
Delete file_tools.py and file_tools_test.py, extending e2b_file_tools.py
to handle both E2B (sandbox) and non-E2B (local SDK working dir) modes.

Each handler checks _get_sandbox(): if a sandbox exists, it routes to
the E2B filesystem; otherwise it falls back to the SDK working directory
with path validation, truncation detection, per-path edit locking, and
large content warnings.

All 104 tests from both files now live in e2b_file_tools_test.py.
tool_adapter.py imports exclusively from e2b_file_tools.
2026-04-12 08:37:55 +00:00
majdyz
98f0ddd99d fix(copilot): disallow built-in Read in non-E2B mode for consistency
Read was already disallowed in E2B mode (prod/dev) via
_SDK_BUILTIN_FILE_TOOLS and has been working fine — the LLM uses our
MCP read_file which handles tool-results paths via
is_allowed_local_path(). Extend the same disallow to non-E2B for
consistency: all file I/O now goes through our MCP tools regardless
of mode.

Also removes Read from WORKSPACE_SCOPED_TOOLS since it's now fully
disallowed (was only there as a workspace-path validator for the
CLI built-in Read which is no longer available).
2026-04-12 08:06:34 +00:00
majdyz
05477f2daa fix(copilot): address third CodeRabbit review cycle on proxy
- Preserve multi-valued response headers (e.g. Set-Cookie) by using
  clean_response_headers -> CIMultiDict instead of dict(headers)
- Use sock_connect + sock_read timeouts instead of total so long-lived
  SSE streaming responses aren't killed after 600s
- Log the configured bind_host instead of hardcoded 127.0.0.1
2026-04-12 07:43:12 +00:00
majdyz
cc3bac13c5 fix(copilot): address second CodeRabbit review cycle
- Fix docstring: default is True, not False (config_test.py)
- Redact exception message from stream-error log for consistency
  with upstream-error log (openrouter_compat_proxy.py)
2026-04-12 07:31:45 +00:00
majdyz
85f64de4cf fix(copilot): address review feedback on compat proxy PR
- Redact upstream URL from proxy error log to prevent leaking internal
  hostnames (openrouter_compat_proxy.py line 431)
- Remove type: ignore suppressors from cli_openrouter_compat_test.py,
  using cast instead for the untyped SDK import
- Fix validator precedence: replace field_validator with model_validator
  so explicit ChatConfig(claude_agent_use_compat_proxy=False) is not
  overridden by the unprefixed CLAUDE_AGENT_USE_COMPAT_PROXY env var
- Add regression test for explicit-kwarg precedence
2026-04-12 07:23:41 +00:00
majdyz
59ee9efc3a fix(copilot): remove required from MCP schemas to fix truncation detection
The MCP SDK validates "required" fields BEFORE calling the Python handler.
When the LLM's output tokens are truncated, the tool call arrives as {}
and the SDK rejects it with an opaque "'file_path' is a required property"
error — our truncation detection code never fires.

Changes:
- Remove "required" arrays from all MCP tool schemas (file_tools,
  e2b_file_tools, tool_adapter Read, _build_input_schema)
- Add empty-args truncation detection to all E2B handlers (read_file,
  write_file, edit_file, glob, grep) and tool_adapter Read handler
- Add partial truncation detection to E2B write_file and edit_file
- Add offset/limit int() parsing safety to E2B read_file handler
- Sanitize write error-path output (type name only, not full exc)
- Remove Write/Edit from WORKSPACE_SCOPED_TOOLS (conflict with
  SDK_DISALLOWED_TOOLS)
- Fix E2B mode read_file double registration in get_copilot_tool_names
- Update schema tests to assert "required" is absent
2026-04-11 23:37:34 +00:00
majdyz
788c163d50 fix(copilot): restore required keys in Write and Read tool schemas
Previous commit accidentally dropped the "required" key from
WRITE_TOOL_SCHEMA and READ_TOOL_SCHEMA.  Restores them to prevent
MCP SDK from accepting incomplete tool calls.
2026-04-11 23:32:01 +00:00
majdyz
7bbfbda49c fix(copilot): evict per-path edit locks after use to prevent memory leak
Clean up _edit_locks entries after the edit completes when no other
coroutine is waiting on the same path.  Prevents unbounded growth of
the lock dictionary in long-running server deployments.
2026-04-11 23:17:29 +00:00
majdyz
e4c044913a fix(copilot): add read_file truncation detection and Edit per-path lock
- Add truncation recovery to read_file: when offset/limit are present
  but file_path is missing, return actionable truncation error instead
  of generic "file_path is required"
- Add per-path asyncio lock around Edit's read-modify-write cycle to
  prevent parallel edits on the same file from silently dropping changes
- Add tests for read truncation detection
2026-04-11 23:07:34 +00:00
majdyz
f913c52980 fix(copilot): address review feedback on unified file tools
- Sanitize error-path output in write failures: use os.path.basename()
  instead of exposing full resolved path in error messages
- Validate offset/limit parsing in read_file: wrap int() calls in
  try/except to return clean error on non-integer input
- Fix duplicate read_file registration in E2B mode: skip unified
  read_file when use_e2b=True since E2B_FILE_TOOLS already registers it
- Add tests for invalid offset/limit input
2026-04-11 22:55:14 +00:00
majdyz
e87111e5eb fix(copilot): update security test for Edit tool blocking
The Edit tool is now blocked in SDK_DISALLOWED_TOOLS (same as Write),
so the security hooks test must assert denial instead of allowing it.
2026-04-11 17:59:03 +00:00
majdyz
0fa24c8332 feat(copilot): unified MCP Read and Edit tools to prevent truncation data loss
Add read_file and Edit MCP tools following the same pattern as the
existing unified Write tool.  Both route to the E2B sandbox when active
and fall back to the SDK working directory in non-E2B mode.

Read tool (read_file):
- Reads files with cat -n formatted line numbers
- Supports offset/limit for large files
- Binary file detection by extension
- Path validation via is_allowed_local_path()
- CLI built-in Read is NOT disabled (used internally for oversized
  tool results)

Edit tool:
- Targeted find-and-replace with old_string/new_string
- replace_all flag for multi-occurrence replacements
- Uniqueness check when replace_all=false
- Partial truncation detection with actionable guidance
- CLI built-in Edit IS disabled in SDK_DISALLOWED_TOOLS

Both tools are registered in create_copilot_mcp_server() and included
in get_copilot_tool_names() for both E2B and non-E2B modes.
2026-04-11 17:57:45 +00:00
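The Edit tool's uniqueness rule from 0fa24c8332 can be sketched as a pure string function. The parameter names `old_string`, `new_string`, and `replace_all` come from the commit message; the function name `apply_edit` and the error wording are hypothetical.

```python
def apply_edit(text, old_string, new_string, replace_all=False):
    """Replace old_string with new_string, refusing ambiguous single edits."""
    if not old_string:
        raise ValueError("old_string must not be empty")
    occurrences = text.count(old_string)
    if occurrences == 0:
        raise ValueError("old_string not found")
    if occurrences > 1 and not replace_all:
        raise ValueError(
            f"old_string matches {occurrences} times; "
            "add surrounding context or set replace_all"
        )
    return text.replace(old_string, new_string)


single = apply_edit("a b a", "b", "c")
every = apply_edit("a b a", "a", "z", replace_all=True)
```

Rejecting a non-unique match forces the caller to supply enough context to identify one location, instead of silently editing the first occurrence.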
majdyz
82018770c2 fix(copilot): update security test for Write tool blocking, fix pre-existing lint
- Update test_write_within_workspace_allowed -> test_write_builtin_blocked
  to reflect that SDK built-in Write is now intentionally blocked
- Fix pre-existing black formatting in platform_cost_test.py
2026-04-11 17:37:32 +00:00
majdyz
26f91fd6c4 style(copilot): fix isort import formatting in file_tools.py 2026-04-11 17:24:12 +00:00
majdyz
e759e14feb refactor(copilot): extract truncation check to shared helper, sanitize error paths
- DRY: extract _check_truncation() shared between E2B and non-E2B handlers
- Security: use os.path.basename() in path validation errors to avoid
  leaking internal directory structure
2026-04-11 17:19:08 +00:00
majdyz
9f0ade1642 fix(copilot): unified MCP Write tool to prevent truncation data loss
Replace the CLI's built-in Write tool with a unified MCP Write tool that
detects and handles output-token truncation gracefully instead of losing
user work with opaque "'file_path' is a required property" errors.

The new tool:
- Works in both E2B and non-E2B modes
- Detects partial truncation (content present but file_path missing)
- Detects complete truncation (empty args) with actionable guidance
- Warns on large content (>50K chars) that succeeded
- Places file_path first in schema to survive truncation better
- Validates paths stay within the SDK working directory

Also blocks the CLI built-in Write via SDK_DISALLOWED_TOOLS and
strengthens the system prompt's large-file writing guidance.

Fixes production truncation bug reported in sessions:
  1a0ef47f-9711-41d2-91b8-df9dff9ecfc6
  2bb9fb0d-05bd-4194-8f3f-88fbe3d8b965
  8226d82e-bf9c-48e3-826f-80048d0fb7b4
2026-04-11 17:16:16 +00:00
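The two truncation cases listed in 9f0ade1642 can be sketched as a small argument check. The function name `check_truncation` and the messages are hypothetical; only the two conditions (empty args, and content present without file_path) come from the commit message.

```python
def check_truncation(args):
    """Return an error message if the tool call looks truncated, else None."""
    if not args:
        # Complete truncation: the model ran out of output tokens before
        # emitting any argument.
        return "Tool call arrived with no arguments; retry with smaller content."
    if "content" in args and not args.get("file_path"):
        # Partial truncation: content was emitted but file_path was lost.
        return "file_path is missing; the call was likely truncated."
    return None


empty = check_truncation({})
partial = check_truncation({"content": "print('hi')"})
complete = check_truncation({"file_path": "a.py", "content": "print('hi')"})
```

Returning an actionable message lets the model retry with smaller chunks rather than seeing an opaque schema error.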
majdyz
c70e34c30e fix(backend/copilot): prevent duplicate assistant text after mid-loop pending drain
Track _flushed_assistant_text_len on _BaselineStreamState so the finally
block only appends assistant text produced AFTER the last mid-loop flush.
Without this, state.assistant_text (all rounds) vs state.session_messages
(post-flush only) desync caused the startswith(recorded) dedup to fail,
duplicating round-1 assistant text in session.messages.

Adds regression test in service_unit_test.py.
2026-04-11 15:00:25 +00:00
majdyz
d49ffac0a1 fix(backend/copilot): flush buffered rounds before mid-loop pending drain and wrap turn-start persist
Address three review comments on the pending-message PR:

1. (Blocker) Mid-loop pending drain now flushes state.session_messages
   into session.messages before appending the pending user message, so
   assistant+tool entries from completed rounds land in chronological
   order. Without this, the next turn's replay could hit OpenAI tool-call
   ordering errors (user message interposed between assistant tool_call
   and its tool result).

2. (Should-Fix) Turn-start upsert_chat_session wrapped in try/except so
   a transient DB failure doesn't silently lose messages already popped
   from Redis.  Matches the pattern used in mid-loop and SDK drain paths.

3. (Nice-to-Have) Added TestMidLoopPendingFlushOrdering regression test
   in service_unit_test.py that replays the production flush sequence
   and asserts chronological ordering of assistant/tool/pending entries.
2026-04-11 14:55:46 +00:00
majdyz
2087e36350 fix(util/service): type RemoteCallError.extras and add server-side round-trip tests
Two related issues on the GraphValidationError-over-RPC path:

1. RemoteCallError.extras was typed as Optional[dict[str, Any]], which
   lost all validation guarantees at the wire-format boundary. Today
   the field only carries node_errors (a known-JSON-safe nested dict),
   but Any left the door open for a future exception to stuff a
   non-serializable payload in — blowing up inside FastAPI's JSON
   encoder at handler time instead of at test time. Introduce a typed
   RemoteCallExtras Pydantic model with an explicit
   node_errors: Optional[dict[str, dict[str, str]]] field and update
   both the server-side packer and client-side unpacker to use it.
   Non-serializable sneak-ins now fail at model validation instead of
   inside JSONResponse.

2. The previous tests only exercised the client-side deserialisation
   with a mocked wire payload — there was no symmetric test that the
   server handler actually packs node_errors into extras. If someone
   accidentally dropped the isinstance(exc, GraphValidationError)
   branch in _handle_internal_http_error, the client tests would still
   pass because they forged the payload directly. Add
   test_graph_validation_error_server_handler_packs_node_errors that
   invokes the server handler with a real GraphValidationError and
   validates the resulting JSONResponse body, plus
   test_graph_validation_error_round_trips_through_handlers that wires
   the server handler output through the client handler end to end —
   either side drifting now fails this test.
2026-04-11 12:07:19 +00:00
majdyz
1007e03b20 refactor(executor/utils): extract credential error markers + parity tests
is_credential_validation_error_message was matched against hand-typed
strings that had to be kept in sync with the four raise sites inside
_validate_node_input_credentials by convention only. Adding a new
credential error string would have silently regressed the copilot's
credential-race fallback to a plain text error, with nothing to catch
the drift at test time.

Extract the error templates to module-level constants
(CRED_ERR_REQUIRED, CRED_ERR_INVALID_PREFIX, CRED_ERR_NOT_AVAILABLE_PREFIX,
CRED_ERR_UNKNOWN_PREFIX, CRED_ERR_INVALID_TYPE_MISMATCH) used from
both raise sites and the matcher, with a _CREDENTIAL_ERROR_MARKERS
table that classifies each as exact-match or prefix-match. Add a
parity test that asserts each constant-emitted message is recognised
by the public matcher (including typical runtime suffixes from the
f-strings), plus a case-insensitive check and a negative test. Now
adding a new credential error means adding a constant, and the
test_credential_error_markers_cover_all_raise_sites test fails if the
matcher drifts.
2026-04-11 12:07:03 +00:00
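The marker table described in 1007e03b20 might be structured roughly as follows. The constant and function names appear in the commit message, but the message texts, the two-entry table, and the matching logic are assumptions made for this sketch.

```python
# Hypothetical message texts; the real constants live in executor/utils.
CRED_ERR_REQUIRED = "These credentials are required"
CRED_ERR_INVALID_PREFIX = "Invalid credentials:"

# Each marker is classified as an exact match or a prefix match.
_CREDENTIAL_ERROR_MARKERS = [
    (CRED_ERR_REQUIRED, "exact"),
    (CRED_ERR_INVALID_PREFIX, "prefix"),
]


def is_credential_validation_error_message(message):
    """Case-insensitive check against the marker table."""
    lowered = message.strip().lower()
    for marker, mode in _CREDENTIAL_ERROR_MARKERS:
        needle = marker.lower()
        if mode == "exact" and lowered == needle:
            return True
        if mode == "prefix" and lowered.startswith(needle):
            return True
    return False


# A raise site would build its message from the same constant.
raised = f"{CRED_ERR_INVALID_PREFIX} token expired"
```

Because raise sites and the matcher share the constants, changing a message in one place cannot leave the other out of sync.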
majdyz
260bbb28bd fix(backend/copilot): log GraphValidationError race path for observability
Both _run_agent and _schedule_agent silently swallowed the
GraphValidationError that triggers the credential-race fallback and
returned the inline setup card with no log. That left the race invisible
to oncall — every recovered request looked identical to a cold-cache
first attempt, and credential-drift rates could not be monitored.

Add a logger.warning at both race catch sites that captures user_id,
graph_id, and the raw node_errors so SRE can track how often the
prereq check drifts from the executor/scheduler re-validation. Keeps
the user-facing behaviour unchanged — still returns the inline card —
but makes the race observable.
2026-04-11 12:06:47 +00:00
majdyz
2f24091c17 fix(platform): simplify stripe customer race protection
Revert the tentative update_many conditional guard (prisma where-clause
null semantics are fiddly and the test suite mocks get_stripe_customer_id
end-to-end, so a real prisma error wouldn't be caught locally). The
idempotency_key on Customer.create is sufficient: Stripe collapses
concurrent + retried calls to the same Customer object for 24h, which
comfortably covers every realistic in-flight retry window.

Also invalidate the get_user_by_id cache after the DB write so the
freshly-persisted stripeCustomerId is visible on the next read.
2026-04-11 12:00:58 +00:00
majdyz
8b93cea4d4 fix(platform): harden Stripe billing flow against race + replay edges
Address review findings on the subscription tier billing PR:

1. get_stripe_customer_id race: two concurrent calls (double-click,
   retried request) could each create a Stripe Customer for the same
   user, leaving an orphaned billable customer. Pass an idempotency_key
   so Stripe collapses concurrent + retried calls server-side, and use
   a conditional update_many so the loser of a longer-window race
   re-reads the persisted ID instead of overwriting.

2. update_subscription_tier no-op short-circuit: if the user is already
   on the requested paid tier, return without creating a Checkout
   Session. Without this guard, a duplicate request creates a second
   subscription for the same price; the user would be charged for both
   until _cleanup_stale_subscriptions runs from the resulting webhook —
   which only fires after the second charge has cleared.

3. stripe_webhook payload defensive extraction: a malformed payload
   (missing/non-dict data.object, missing id) would raise KeyError /
   TypeError after signature verification, which Stripe interprets as
   a delivery failure and retries forever. Validate shape, log a
   warning, and ack with 200 so Stripe stops retrying.

4. _cleanup_stale_subscriptions: bump the swallowed-error log from
   warning to exception so Sentry surfaces it as an error, include
   the customer/sub IDs needed for manual reconciliation, and add a
   TODO referencing the missing periodic reconcile job that the
   docstring already promises as the backstop.
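
The shape validation from item 3 could look roughly like the sketch below. `extract_event_object` and `handle_verified_event` are hypothetical names, and the sketch assumes the event has already passed signature verification and been parsed into plain Python objects.

```python
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)


def extract_event_object(event: Any) -> Optional[dict]:
    """Return event["data"]["object"] if it is a dict with a non-empty
    string id, otherwise None. Never raises KeyError or TypeError."""
    if not isinstance(event, dict):
        return None
    data = event.get("data")
    if not isinstance(data, dict):
        return None
    obj = data.get("object")
    if not isinstance(obj, dict):
        return None
    obj_id = obj.get("id")
    if not isinstance(obj_id, str) or not obj_id:
        return None
    return obj


def handle_verified_event(event: Any) -> int:
    """Return the HTTP status to send back for an already-verified event."""
    obj = extract_event_object(event)
    if obj is None:
        # Malformed payload: log and acknowledge so the sender stops retrying.
        logger.warning("Ignoring webhook event with unexpected shape")
        return 200
    # ... real processing of obj would go here ...
    return 200
```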
2026-04-11 11:48:33 +00:00
majdyz
72660f8df0 test(orchestrator): use AsyncMock for charge_node_usage in remaining tests
charge_node_usage is async and is awaited in _execute_single_tool_with_manager.
Mocking it as MagicMock returned a non-awaitable tuple, which raised TypeError
inside the tool execution path; the error was silently swallowed by the
orchestrator's catch-all and converted into a 'Tool execution failed' string,
so downstream assertions effectively ran against an error response instead
of the success path. Switch both tests to AsyncMock to actually exercise the
post-tool charging branch — matching the fix already applied in
test_orchestrator.py:929.
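
The failure mode can be reproduced with the standard library alone. In this sketch `charge_and_report` is a hypothetical stand-in for the awaiting call site, and its broad `except` mimics the orchestrator's catch-all.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def charge_and_report(charge_node_usage) -> str:
    """Await the charge call; a catch-all turns any failure into a string."""
    try:
        cost, balance = await charge_node_usage()
    except Exception:
        # MagicMock returns a plain tuple, and awaiting a tuple raises
        # TypeError, which lands here.
        return "Tool execution failed"
    return f"charged {cost}, balance {balance}"


# MagicMock: the call returns (1, 99) directly, which is not awaitable.
sync_result = asyncio.run(charge_and_report(MagicMock(return_value=(1, 99))))

# AsyncMock: the call returns a coroutine that resolves to (1, 99).
async_result = asyncio.run(charge_and_report(AsyncMock(return_value=(1, 99))))
```

With `MagicMock` the assertions downstream only ever see the error string; with `AsyncMock` the success branch actually runs.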
2026-04-11 11:47:18 +00:00
majdyz
349daf48f7 test(copilot/config): flip default-compat-proxy test for dev preview
Dev-preview flips ``claude_agent_use_compat_proxy`` default to True
so the bundled CLI in claude-agent-sdk 0.1.58 works out of the box.
Update the no-env-var test accordingly so rebasing the upstream
config test on this branch doesn't fail.
2026-04-11 11:05:04 +00:00
majdyz
d702bcfae2 test(copilot/sdk-compat): skip bare-CLI reproduction when proxy enabled
When `claude_agent_use_compat_proxy=True` the operator has explicitly
opted into the workaround. The bare-CLI reproduction stops being a
useful signal in that mode — what matters is the *upstream* (post-
proxy) staying clean, which is covered by
`test_cli_via_compat_proxy_emits_clean_requests_to_upstream`.

Skip the bare test in that case so the dev-preview branch (0.1.58 +
proxy on) goes fully green instead of having an intentional-but-loud
failure on every CI run.

When the proxy is disabled (the default on the standalone proxy and
plumbing PRs), the bare test continues to run unchanged and serves
as the regression detector for the bundled CLI version.
2026-04-11 11:05:04 +00:00
majdyz
2af87616de test(copilot/sdk-compat): add proxy-routed reproduction variant
Adds `test_cli_via_compat_proxy_emits_clean_requests_to_upstream` so
the compat proxy has a real end-to-end regression guard: spawn the
bundled CLI through the proxy against a fake upstream, capture what
the upstream sees, assert it's clean.

The bare reproduction test
(`test_cli_does_not_send_openrouter_incompatible_features`) keeps
its original semantics — proves the bundled CLI is or isn't broken
upstream — so we still get a clean bisect signal when changing the
SDK pin.

Together the two tests give:

* bare CLI clean → bare test passes; proxy test passes (no-op).
* bare CLI broken → bare test fails (intentional bisect signal);
  proxy test passes if and only if the proxy successfully strips
  the forbidden patterns.

This means that on this dev preview branch (SDK 0.1.58 with proxy on),
CI confirms both that:

* the regression actually exists (bare test fails; that's the
  reproduction the user asked for), and
* the proxy actually fixes it (proxy test passes; that's the
  workaround validation).
2026-04-11 11:05:04 +00:00
majdyz
5cf60587ef chore(deps): bump claude-agent-sdk to 0.1.58 with compat proxy enabled
Dev preview PR — combines the cli_path plumbing (#12741), the
in-process compat proxy (#12745), and the SDK bump in one branch so
we can dogfood the full upgrade end-to-end.

Changes:

* `claude-agent-sdk` -> 0.1.58 (bundled CLI 2.1.97). Gets us all the
  blocked features:
    - `exclude_dynamic_sections` cross-user prompt cache hits
      (0.1.57) — directly amplifies #12725
    - `AssistantMessage.usage` per-turn token tracking (0.1.49) —
      cost attribution
    - `task_budget` (0.1.51) — per-task cost ceiling
    - `get_context_usage()` (0.1.52) — context window monitoring
    - MCP large-tool-result truncation fix (0.1.55)
    - MCP HTTP/SSE buffer leak fix (CLI 2.1.97) — known production
      memory creep
    - 429 retry exponential-backoff fix (CLI 2.1.97) — production
      rate-limit recovery
    - `--resume` cache miss regression fix (CLI 2.1.90)
    - SDK session quadratic-write fix (CLI 2.1.90)

* `ChatConfig.claude_agent_use_compat_proxy` default flipped from
  `False` -> `True`. The bundled CLI in 0.1.55+ injects the
  `context-management-2025-06-27` beta header which OpenRouter
  rejects (anthropics/claude-agent-sdk-python#789). The proxy strips
  it transparently. Disable explicitly only if you've pinned to a
  CLI version in `_KNOWN_GOOD_BUNDLED_CLI_VERSIONS_DIRECT`.

* `sdk_compat_test.py` pin assertion split into two known-good sets:
    - `_KNOWN_GOOD_BUNDLED_CLI_VERSIONS_DIRECT` — works without the
      proxy ({"2.1.63", "2.1.70"})
    - `_KNOWN_GOOD_BUNDLED_CLI_VERSIONS_VIA_PROXY` — works only with
      the compat proxy enabled ({"2.1.97"})
  The test now requires `claude_agent_use_compat_proxy=True` for
  proxy-only versions, so disabling the proxy on a fresh checkout
  with this PR fails fast with a clear error.

Operational rollout (when ready to ship beyond dev preview):

1. Merge #12741 (plumbing + reproduction test)
2. Merge #12745 (proxy module — opt-in default off)
3. Merge this PR (bumps SDK + flips default to on)
4. Watch production, with the existing reproduction test running
   continuously as a regression guard
5. If anything goes wrong: revert this PR (proxy becomes opt-in
   again, SDK back to whichever version is in the previous merge)

Dev preview usage: deploy this branch with no env-var changes —
the proxy is on by default. The reproduction test will continue
to pass against the bundled CLI 2.1.97 when (and only when) the
proxy successfully strips the forbidden patterns.
2026-04-11 11:05:03 +00:00
majdyz
428ed39a1a fix(copilot/sdk-proxy): abort transport on mid-stream upstream error
The previous fix set a ``stream_error`` flag and returned the
prepared ``StreamResponse`` without calling ``write_eof()``,
assuming aiohttp would leave the body dangling. It doesn't:
aiohttp's handler dispatcher finalises any returned
``StreamResponse`` on the way out (writing the chunked terminator /
content-length / EOF), so a regression test with a real mid-stream
failure still saw the client get a clean 200 body.

Correct fix: on the stream-error path, abort the underlying
transport directly via ``request.transport.abort()`` and then
re-raise the original stream error out of the handler. Aborting
drops the TCP socket mid-response so the client's parser surfaces
a ``ClientPayloadError`` / ``ServerDisconnectedError`` and the
caller sees the truncation as a real transport failure.

Also rewrote the regression test to use a raw
``asyncio.start_server`` TCP handler that sends a chunked response
header plus one partial chunk and then hard-closes the socket
(``transport.abort()``) — this is the one failure mode that
reliably propagates through aiohttp's ``iter_any()`` as a
``ClientError`` for the proxy to detect. Verified locally: the
test now fails with the expected ``ClientPayloadError`` on the
client side instead of silently returning 200.
2026-04-11 11:04:50 +00:00
majdyz
8742c5e5b9 fix(copilot/sdk-proxy): treat empty sdk_env ANTHROPIC_BASE_URL as opt-out
Claude Code subscription mode intentionally sets
``sdk_env['ANTHROPIC_BASE_URL'] = ""`` to disable any base-URL
override and keep the CLI talking to Anthropic directly. The
previous ``or``-chained lookup evaluated the empty string as falsy
and fell through to ``os.environ.get("ANTHROPIC_BASE_URL")`` and
then to ``OPENROUTER_BASE_URL``, silently starting the compat proxy
for a session that had explicitly opted out — which breaks
subscription auth.

Use a presence check on ``sdk_env`` instead: if the key is present
with an empty value it's a hard "no-proxy" signal, so skip the
OpenRouter fallback even when ``openrouter_active`` is True. The
process-env fallback and the OpenRouter fallback still cover the
original cases (no sdk_env override, OpenRouter is the routing
provider for this session).
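
A minimal sketch of the presence check, assuming a hypothetical helper `resolve_upstream_base_url` that returns `None` when the proxy should not start. The URLs in the usage lines are placeholders.

```python
from typing import Mapping, Optional


def resolve_upstream_base_url(
    sdk_env: Mapping[str, str],
    process_env: Mapping[str, str],
    openrouter_active: bool,
    openrouter_base_url: str,
) -> Optional[str]:
    """Pick the upstream for the compat proxy, or None to skip the proxy."""
    if "ANTHROPIC_BASE_URL" in sdk_env:
        # Presence check, not truthiness: an empty string is an explicit
        # opt-out and must not fall through to the fallbacks below.
        return sdk_env["ANTHROPIC_BASE_URL"] or None
    from_process = process_env.get("ANTHROPIC_BASE_URL")
    if from_process:
        return from_process
    if openrouter_active:
        return openrouter_base_url
    return None


# Opt-out wins even though both fallbacks are available.
opted_out = resolve_upstream_base_url(
    {"ANTHROPIC_BASE_URL": ""},
    {"ANTHROPIC_BASE_URL": "https://proxy.example"},
    True,
    "https://openrouter.example/api",
)
```

An `or`-chained lookup would have returned `"https://proxy.example"` for the same inputs, which is the bug this commit fixes.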

Flagged by sentry review on #12745 (thread 3067906804).
2026-04-11 10:48:04 +00:00
majdyz
370499c8dc fix(copilot/sdk-proxy): don't signal clean EOF on mid-stream error
When an ``aiohttp.ClientError`` fires mid-stream the previous code
logged it and then called ``downstream.write_eof()``, which tells
the downstream client "stream complete" on top of a partial,
truncated body. Clients then silently consumed the corrupt response
as if it were a clean success.

Track the stream error in a local variable and, when it's set, skip
the ``write_eof`` call and ``force_close`` the downstream response
so aiohttp drops the connection mid-body. The client's parser then
raises a ``ClientPayloadError`` / ``ServerDisconnectedError`` and
the failure is surfaced instead of silently producing garbage.

Added a regression test that spins up an upstream which calls
``force_close`` mid-response; the proxy must propagate the failure
to the client (exception on ``resp.read()``), never return a clean
body.

Flagged by sentry review on #12745 (thread 3067897364).
2026-04-11 10:41:40 +00:00
majdyz
cd9924f03e fix(copilot/sdk-proxy): address CodeRabbit follow-ups
Three follow-up findings from CodeRabbit's second-pass review:

* The forbidden-pattern scanner in ``cli_openrouter_compat_test``
  relied on a substring match against the prettified form
  `"type": "tool_reference"` (with a space). The CLI is free to
  emit compact JSON like `{"type":"tool_reference"}` which would
  slip past the scanner and false-pass the reproduction test.
  Replaced the substring check with a JSON walker that catches any
  dict with `type == "tool_reference"` regardless of serialisation,
  with a whitespace-tolerant regex fallback for malformed bodies.
  Added two regression tests (compact form, malformed fallback).

* The timeout path in ``_run_cli_against_fake_server`` called
  ``proc.kill()`` and returned immediately, leaving an unreaped
  subprocess until event-loop shutdown. Reap it with a 5-second
  bounded ``proc.communicate()`` wait after the kill.

* ``test_proxy_returns_502_on_upstream_failure`` swallowed
  ``aiohttp.ClientError`` / ``asyncio.TimeoutError`` on the outer
  ``client.post``. That outer call talks to the *proxy* on
  localhost — not the dead upstream — so any exception there
  indicates a proxy crash and must fail the test, not be caught.
  Removed the except block and bumped the client timeout to 10s to
  give the proxy room to return its 502. Also asserts the response
  body contains the generic "upstream error" text so a regression
  that replaces the 502 with a different status is caught.
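
The first finding (JSON walker with a regex fallback) can be sketched as follows. `body_has_tool_reference` is a hypothetical name; the real scanner in the test module may be structured differently.

```python
import json
import re
from typing import Any

# Whitespace-tolerant fallback for bodies that do not parse as JSON.
_TOOL_REFERENCE_RE = re.compile(r'"type"\s*:\s*"tool_reference"')


def _contains_tool_reference(node: Any) -> bool:
    """Recursively look for any dict whose "type" is "tool_reference"."""
    if isinstance(node, dict):
        if node.get("type") == "tool_reference":
            return True
        return any(_contains_tool_reference(v) for v in node.values())
    if isinstance(node, list):
        return any(_contains_tool_reference(v) for v in node)
    return False


def body_has_tool_reference(raw_body: str) -> bool:
    """Detect the forbidden block regardless of how the JSON is serialised."""
    try:
        parsed = json.loads(raw_body)
    except json.JSONDecodeError:
        return _TOOL_REFERENCE_RE.search(raw_body) is not None
    return _contains_tool_reference(parsed)
```

Walking the parsed structure also avoids false positives when the text `tool_reference` merely appears inside an unrelated string value.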
2026-04-11 10:30:31 +00:00
majdyz
5d6cf91642 fix(copilot): handle bool default in compat-proxy env validator
The ``get_claude_agent_use_compat_proxy`` validator added in the
previous commit used ``if v is None`` to decide when to fall back to
the unprefixed env var. But unlike ``claude_agent_cli_path`` (which
defaults to ``None``), this field has ``default=False``. Pydantic-
settings passes the default bool into a ``mode="before"`` validator
when no explicit value is provided, so the ``is None`` branch never
fired and the unprefixed ``CLAUDE_AGENT_USE_COMPAT_PROXY`` env var
was silently ignored.

Switch to checking the raw process env directly: if the prefixed
``CHAT_CLAUDE_AGENT_USE_COMPAT_PROXY`` is set we trust Pydantic's
parsed value (which preserves any explicit ``false``), otherwise we
return the unprefixed env var's raw string so Pydantic's usual
truthy/falsy coercion handles it.

Added a new ``TestClaudeAgentUseCompatProxyEnvFallback`` class
covering both env-var names, the prefixed-wins-over-unprefixed
precedence (including the ``CHAT_...=false`` + unprefixed ``=true``
case), and the default. Also added the mirror tests for
``claude_agent_cli_path`` and included the new env var names in the
``_ENV_VARS_TO_CLEAR`` fixture so existing tests don't leak.

Flagged by sentry review on #12745 (thread 3067888297).
2026-04-11 10:26:17 +00:00
majdyz
0554a0ae35 fix(copilot/sdk-proxy): address PR review: RFC 7230 hop-by-hop, timeouts, cancellation, provider gating

Addresses all seven review threads on #12745 (coderabbit + sentry)
in a single commit because they overlap in the same file cluster:

config.py
---------
* ``claude_agent_use_compat_proxy`` gains a ``field_validator`` that
  reads the unprefixed ``CLAUDE_AGENT_USE_COMPAT_PROXY`` in addition
  to the Pydantic-prefixed ``CHAT_`` form, matching the same dual-name
  pattern already used by ``api_key`` / ``base_url`` /
  ``claude_agent_cli_path`` and keeping parity with the docstring and
  the PR description. Without this the operator-facing env var was
  silently ignored because of ``env_prefix = "CHAT_"``.

openrouter_compat_proxy.py
--------------------------
* ``_HOP_BY_HOP_HEADERS`` now includes the canonical ``trailer``
  (singular per RFC 7230 §4.4) alongside the plural ``trailers``;
  ``clean_request_headers`` additionally drops every header whose
  name is listed in the incoming ``Connection`` field value (§6.1
  extension hop-by-hop headers), case-insensitively — previously
  extension hop-by-hop headers could leak upstream.

* ``strip_tool_reference_blocks`` now *removes* dict-valued
  ``tool_reference`` children from their parent dict instead of
  rewriting them to ``null``; the stated "strip anywhere" semantics
  were broken on nested dict assignments and still produced
  schema-invalid payloads upstream. Genuine ``None`` children on
  non-dict values are still preserved.

* ``_handle`` upstream-call error handler now catches
  ``asyncio.TimeoutError`` alongside ``aiohttp.ClientError`` —
  ``aiohttp.ClientTimeout`` raises ``asyncio.TimeoutError`` (not
  ``aiohttp.ClientError``), so hung upstreams used to escape as a
  generic 500 instead of the documented 502.

* Streaming-response handler no longer suppresses
  ``asyncio.CancelledError``. It's now split into its own except
  branch, releases the upstream body, and re-raises so cooperative
  task cancellation works as intended (cancellation while mid-stream
  was previously being caught alongside ``ClientError`` and silently
  swallowed, leading to hung request handlers on client disconnects
  / shutdowns).

* ``start()`` wraps the ``runner.setup() / site.start()`` sequence in
  try/except that tears down both the client session and the
  (partially-initialised) runner on any exception, so failed startups
  never leak resources. The attributes are only published to the
  instance after the full chain succeeds.
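
As an illustration of the hop-by-hop change above, a minimal sketch of header cleaning. The header set follows RFC 7230 §6.1 plus the two `trailer` spellings named in this commit; the real `clean_request_headers` may drop additional headers and its signature is an assumption here.

```python
from typing import Dict, Mapping

# Standard hop-by-hop headers, lowercase, with both "trailer" spellings.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)


def clean_request_headers(headers: Mapping[str, str]) -> Dict[str, str]:
    """Drop hop-by-hop headers, including any named in Connection."""
    connection_listed = {
        token.strip().lower()
        for name, value in headers.items()
        if name.lower() == "connection"
        for token in value.split(",")
        if token.strip()
    }
    drop = _HOP_BY_HOP_HEADERS | connection_listed
    return {n: v for n, v in headers.items() if n.lower() not in drop}


cleaned = clean_request_headers(
    {
        "Connection": "keep-alive, X-Custom-Hop",
        "X-Custom-Hop": "1",
        "Trailer": "Expires",
        "Authorization": "Bearer t",
        "Content-Type": "application/json",
    }
)
```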

service.py
----------
* The compat-proxy startup is now gated on there actually being an
  Anthropic-compatible upstream to forward to. Previously the code
  fell back to ``OPENROUTER_BASE_URL`` unconditionally, which would
  silently re-route direct-Anthropic / Claude Code subscription
  sessions through OpenRouter and break auth. The new gate is:
  explicit ``ANTHROPIC_BASE_URL`` in ``sdk_env`` or the process env,
  OR ``ChatConfig.openrouter_active`` (OpenRouter is configured as
  the session's routing provider). When neither holds we log a
  warning and skip proxy startup — the feature is opt-in and named
  "OpenRouter compatibility", so no-oping direct-Anthropic sessions
  is the safe default. The success log line also drops the upstream
  URL to match the taint-analysis guidance already applied to
  ``openrouter_compat_proxy.start``.

Tests
-----
* Added regression tests for the dict-valued tool_reference fix, the
  Connection-listed header stripping (with case-insensitive matching),
  and an end-to-end 502-on-upstream-timeout test (fake upstream that
  sleeps longer than the proxy's request timeout). The hop-by-hop
  completeness test now also pins ``trailer`` / ``trailers``.
2026-04-11 10:18:50 +00:00
majdyz
fed728e546 fix(copilot/sdk-proxy): drop upstream from log message entirely
Previous fix logged the parsed netloc instead of the full URL, but
CodeQL's `py/clear-text-logging-sensitive-data` taint analysis still
traces the value through `urlparse(target_base_url).netloc` and
flags the log call. Address by dropping the upstream component from
the log entirely — only the local bind port is logged. The upstream
endpoint is discoverable from `ChatConfig` and exposed via the
`target_base_url` property for callers that need it.
2026-04-11 10:13:49 +00:00
majdyz
93f27ffdf6 fix(copilot/sdk-proxy): address CodeQL findings + isort drift
CodeQL flagged two issues in the new compat proxy:

1. `py/clear-text-logging-sensitive-data` (high) — logging
   `self._target_base_url` could leak credentials if a future caller
   passed a URL containing them. Switched to logging only the host
   component (and the local 127.0.0.1 port) so even an
   accidentally-credentialled base URL stays out of logs.

2. `py/stack-trace-exposure` (medium) — returning the upstream
   exception text in the 502 response body could leak internal
   hostnames or stack frames to the client. Changed to a generic
   "upstream error" string; the detailed exception is still logged
   server-side.

Also fixes an isort sorting drift in the test file (private
underscore-prefixed names must sort before public names — local
isort accepted the order, CI's isort did not).
2026-04-11 10:13:48 +00:00
majdyz
0f00972efc feat(copilot): in-process OpenRouter compat proxy for newer Claude SDK
The Claude Code CLI in any `claude-agent-sdk` version above 0.1.47
sends the `context-management-2025-06-27` beta header / body field
that OpenRouter rejects with HTTP 400. This blocks us from upgrading
to take features we want (`exclude_dynamic_sections` cross-user prompt
caching in 0.1.57, `AssistantMessage.usage` per-turn token tracking
in 0.1.49, the MCP large-tool-result truncation fix in 0.1.55, etc).
Tracked upstream at anthropics/claude-agent-sdk-python#789, no fix
released yet.

This commit adds an in-process HTTP middleware that lets the latest
SDK / CLI talk to OpenRouter unchanged. The proxy:

* listens on `127.0.0.1:RANDOM_PORT`,
* receives every CLI request that would normally go to
  `ANTHROPIC_BASE_URL`,
* strips `tool_reference` content blocks (the original 0.1.46+
  regression — defensive, in case the CLI 2.1.70 proxy detection
  ever regresses) and `context-management-2025-06-27` from both the
  request body's `betas` array and the `anthropic-beta` header,
* forwards the cleaned request upstream and streams the response
  back unchanged.
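
The beta-stripping step in the list above can be sketched as two pure helpers. The function names match the helpers listed under "Tests cover" below, but the signatures and the choice to drop an emptied `betas` key are assumptions.

```python
from typing import Optional

FORBIDDEN_BETA = "context-management-2025-06-27"


def strip_forbidden_betas_from_body(body: dict) -> dict:
    """Return a copy of the request body without the forbidden beta."""
    betas = body.get("betas")
    if not isinstance(betas, list):
        return body
    kept = [b for b in betas if b != FORBIDDEN_BETA]
    cleaned = dict(body)
    if kept:
        cleaned["betas"] = kept
    else:
        # Assumption: an empty betas array is dropped rather than forwarded.
        cleaned.pop("betas")
    return cleaned


def strip_forbidden_anthropic_beta_header(value: str) -> Optional[str]:
    """Remove the forbidden token from a comma-separated header value.

    Returns None when nothing is left, so the caller can omit the header.
    """
    tokens = [t.strip() for t in value.split(",")]
    kept = [t for t in tokens if t and t != FORBIDDEN_BETA]
    return ", ".join(kept) if kept else None
```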

Wired via `ChatConfig.claude_agent_use_compat_proxy` (default
`False`, opt-in). When the flag is on, the SDK service starts a
proxy per session, injects its local URL into the spawned CLI
subprocess `env` as `ANTHROPIC_BASE_URL`, and tears it down in the
session's `finally` block.

The proxy is intentionally orthogonal to the existing
`claude_agent_cli_path` override:

* `cli_path` picks **which** CLI binary we run.
* compat proxy rewrites **whatever the chosen binary sends**.

Both can be combined or used independently.

Tests cover:

* the pure stripping helpers (`strip_tool_reference_blocks`,
  `strip_forbidden_betas_from_body`,
  `strip_forbidden_anthropic_beta_header`,
  `clean_request_body_bytes`, `clean_request_headers`) including
  edge cases like empty input, non-JSON bodies, and the
  hop-by-hop header set,
* end-to-end behaviour against a fake upstream server: stripping
  the `tool_reference` block in nested `tool_result.content`,
  rewriting the `anthropic-beta` header,
  removing the forbidden token from the body `betas` array,
  passing through clean requests unchanged, and returning a clear
  502 on upstream failure (no infinite hang).
2026-04-11 10:13:48 +00:00
majdyz
a6e306d28a fix(copilot): accept unprefixed CLAUDE_AGENT_CLI_PATH in config
The new `claude_agent_cli_path` field inherited the `CHAT_` Pydantic
prefix from `ChatConfig`, so the documented `CLAUDE_AGENT_CLI_PATH`
env var was silently ignored — operators following the PR description
or the field docstring would set the unprefixed form and the config
would fall back to the bundled CLI.

Add a `field_validator` that reads `CHAT_CLAUDE_AGENT_CLI_PATH` first
and falls back to the unprefixed `CLAUDE_AGENT_CLI_PATH`, matching the
same pattern already used by `api_key` and `base_url`. The test helper
`_resolve_cli_path` in `cli_openrouter_compat_test.py` mirrors the
same two-name lookup so the reproduction test picks up the override
regardless of which form is set, and a new test covers the prefixed
variant explicitly.

Flagged by sentry review on #12741 (thread IDs 3067725580 and
3067768817) as two instances of the same bug.
2026-04-11 10:11:47 +00:00
majdyz
47852cfdf5 fix(frontend/builder): add navigation race guard to tool-call detection effect
Mirrors the existing `skipNextParseRef` guard on the parse-actions effect.
When `flowID` changes, the reset effect clears `processedToolCallsRef` and
`lastScannedToolCallIndexRef` and queues `setMessages([])`, but the cleared
messages are not yet committed when the tool-call detection effect runs in
the same effect cycle. Without the skip, the effect would re-scan the
previous graph's messages from index 0 and re-fire `onGraphEdited` /
`setQueryStates(flowExecutionID)` for tool calls belonging to the old
graph — triggering a stray `refetchGraph()` on the new graph or
auto-following a stale execution.

Uses a separate `skipNextToolScanRef` so each effect consumes its own
flag independently; a shared ref would let whichever effect ran first
clear the guard before the other could skip.
2026-04-11 10:08:08 +00:00
majdyz
d6f0fcb052 test(copilot/sdk-compat): unit-test the forbidden-pattern scanner
Add direct unit tests for `_scan_request_for_forbidden_patterns` and
`_resolve_cli_path` so the helper logic stays exercised even on CI
runs where the slow end-to-end CLI subprocess test can't capture a
request (sandboxed runner, missing CLI binary, etc).

Brings codecov/patch coverage above the 80% gate. No production
code changes — tests only.
2026-04-11 07:57:04 +00:00
majdyz
feb247d56e chore(backend): drop stray blank line in platform_cost_test.py
Same pre-existing dev-branch lint issue from PR #12739 — black would
reformat this file (extra blank line between two test classes), which
fails the `lint` CI job for any PR branched from current dev.
2026-04-11 07:10:55 +00:00
majdyz
fdb3590693 chore(copilot): add SDK CLI override + OpenRouter compat regression tests
We've been pinned at `claude-agent-sdk==0.1.45` (bundled CLI 2.1.63)
since PR #12294 because every version above introduces a 400 against
OpenRouter. There are two stacked regressions today:

1. CLI 2.1.69 (= SDK 0.1.46) added a `tool_reference` content block in
   `tool_result.content` that OpenRouter's stricter Zod validation
   rejects. CLI 2.1.70 added a proxy-detection workaround but our
   subsequent attempts at 0.1.55 and 0.1.56 still failed.
2. A newer regression — the `context-management-2025-06-27` beta
   header — appears in some CLI version after 2.1.91. Tracked upstream
   at anthropics/claude-agent-sdk-python#789, still open with no fix.

This commit doesn't actually upgrade the SDK — it adds the
infrastructure we need to upgrade safely *when* upstream lands a fix
or when we identify a known-good newer CLI version via bisection:

* `ChatConfig.claude_agent_cli_path` (env: `CLAUDE_AGENT_CLI_PATH`)
  threads through to `ClaudeAgentOptions(cli_path=...)` so we can
  decouple the Python SDK API surface from the CLI binary version.
  `_prewarm_cli` in the CoPilotExecutor honours the same override.

* `test_bundled_cli_version_is_known_good_against_openrouter` pins
  the bundled CLI to a known-good set (`{"2.1.63"}` today). Any
  `claude-agent-sdk` bump that changes the bundled CLI will fail this
  test loudly with a pointer to PR #12294 and issue #789, instead of
  silently re-breaking production.

* `test_sdk_exposes_cli_path_option` is a forward-compat sentinel that
  fails fast if upstream removes the `cli_path` option we depend on
  for the override.

* `cli_openrouter_compat_test.py` is the actual reproduction test:
  spawns the bundled (or `CLAUDE_AGENT_CLI_PATH`-overridden) CLI
  against an in-process aiohttp server pretending to be the Anthropic
  Messages API, captures every request body the CLI sends, and
  asserts that none of them contain the two known forbidden patterns
  (`"type": "tool_reference"` content blocks or
  `"context-management-2025-06-27"` in body or `anthropic-beta`
  header). The fake server returns a minimal valid streamed response
  so the CLI doesn't error out before we can inspect what it sent.
  No OpenRouter API key required — the test reproduces the *mechanism*
  rather than the symptom, so it's deterministic and free to run in CI.

Workflow for verifying a candidate upgrade going forward: bump the
SDK in `pyproject.toml`, push the commit, and watch the CI run for
both tests in `sdk_compat_test.py` and `cli_openrouter_compat_test.py`.
A clean run on both means it's safe to add the new bundled CLI version
to `_KNOWN_GOOD_BUNDLED_CLI_VERSIONS` and merge.
2026-04-11 07:05:05 +00:00
majdyz
83fc444a3d fix(frontend/builder): re-add navigation race guard for parsedActions
When the user navigates between graphs, the flowID-reset effect resets
`lastParsedMessageIndexRef` and the parsed-actions cache, then queues
`setMessages([])`. The parse-actions effect runs in the same effect
cycle — *before* the queued state updates are committed — so its
`messages` closure still belongs to the previous graph. With the index
reset to -1 and the cache empty, it would re-scan those stale messages
from index 0 and briefly flash the previous graph's actions in the new
panel.

A previous guard (`277c19642`) was lost when commit `1935137c1` (the
DefaultChatTransport memoization fix) accidentally dropped the
`if (currentFlowIDRef.current !== flowID) return;` line. That guard
was actually a no-op because `currentFlowIDRef` is updated by an earlier
effect in the same cycle, so the check never fired — the bug was masked
in practice but came back into view when sentry re-flagged it.

Replace the removed line with a one-shot `skipNextParseRef` flag that
the cleanup effect sets only on *actual* navigation (not initial mount,
detected via `prevFlowIDRef`). The parse-actions effect skips one pass
when the flag is set, then clears it. This correctly handles:

  - Initial mount: no skip (flag stays false), first run parses normally.
  - Navigation: skip one pass; next render arrives with fresh messages
    from useChat's re-key and parses them correctly.
  - Same-flowID re-render: cleanup doesn't fire, no skip, normal parse.

New regression test reproduces the navigation race in the parsed-actions
integration suite.

Sentry bug prediction: PRRT_kwDOJKSTjM56RVeU (severity HIGH).
2026-04-11 05:18:25 +00:00
majdyz
693c616bf5 fix(util/cache): properly distinguish missing entries from cached None
The @cached decorator could not differentiate "no entry" from "entry is
None" — both `_get_from_memory` and `_get_from_redis` returned `None`
for misses, and the wrappers checked `result is not None` to decide
whether to recompute. Functions that returned `None` as a valid value
were therefore re-executed on every call, defeating the cache and (for
shared_cache=False) potentially causing per-pod thundering herd against
upstream APIs.

Fix:
- Use a module-level `_MISSING = object()` sentinel for "no entry".
- Wrappers now check `result is not _MISSING` so cached `None` is
  returned correctly.
- Add a `cache_none: bool = True` parameter so callers that *want* the
  retry-on-None behavior (e.g. external API calls returning `None` to
  signal a transient error) can explicitly opt out via `cache_none=False`.
- `_get_stripe_price_amount` opts out: returning None on a Stripe error
  must not poison the 5-minute cache window. Updated its docstring to
  describe the actual behavior.
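
A minimal in-memory sketch of the sentinel approach. It only models positional, hashable arguments and omits TTL, Redis, and async support, so it illustrates the `_MISSING` / `cache_none` logic rather than the real decorator.

```python
import functools
from typing import Any, Callable, Dict, Tuple

# Unique object that can never be a legitimate cached value.
_MISSING = object()


def cached(*, cache_none: bool = True) -> Callable:
    """Memoise by positional args; optionally refuse to store None."""

    def decorator(func: Callable) -> Callable:
        store: Dict[Tuple[Any, ...], Any] = {}

        @functools.wraps(func)
        def wrapper(*args: Any) -> Any:
            result = store.get(args, _MISSING)
            if result is not _MISSING:
                # Cache hit, even if the cached value is None.
                return result
            result = func(*args)
            if result is not None or cache_none:
                store[args] = result
            return result

        return wrapper

    return decorator


calls = {"default": 0, "retry": 0}


@cached()
def lookup(key: str) -> None:
    calls["default"] += 1
    return None


@cached(cache_none=False)
def flaky(key: str) -> None:
    calls["retry"] += 1
    return None


lookup("a"), lookup("a")  # second call is served from the cache
flaky("a"), flaky("a")  # None is not stored, so both calls execute
```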

New tests cover both default (None is cached) and `cache_none=False`
(None is not stored, next call retries) for sync, async, and shared
cache paths.

Sentry bug prediction: PRRT_kwDOJKSTjM56RTEu (severity HIGH).
2026-04-11 05:03:54 +00:00
majdyz
2a6b65fd7b fix(executor): notify low balance after extra-iteration charges
The `_on_node_execution` path called `charge_extra_iterations` and
ignored the returned `remaining_balance`, so users were never notified
when their balance crossed the low-balance threshold via these
post-hoc per-LLM-call charges. `charge_node_usage` already does the
right thing — mirror that pattern here so all charging paths route
through `_handle_low_balance` consistently.

Sentry bug prediction: PRRT_kwDOJKSTjM56Mibw (severity MEDIUM).
2026-04-11 04:47:56 +00:00
majdyz
2611143c67 fix(backend/copilot): simplify credential guidance — run_agent handles creds inline as part of run flow 2026-04-11 09:31:39 +07:00
majdyz
8b60bd5e78 fix(backend/copilot): remove duplicate credential guidance from tool descriptions, fix guide to distinguish missing vs valid-update case 2026-04-11 09:28:52 +07:00
majdyz
6f7bf90769 fix(backend): harden URL validator and add adversarial redirect tests
Reject URLs containing '@', backslashes, or control characters before
urlparse to prevent auth-trick and backslash-normalisation attacks.
Add parametrized tests covering 11 adversarial inputs + valid cases.
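
A sketch of the pre-parse rejection described above. `is_safe_redirect_url` and the `allowed_hosts` allow-list are hypothetical; only the three pre-`urlparse` checks come from this commit message.

```python
from typing import FrozenSet
from urllib.parse import urlparse


def is_safe_redirect_url(url: str, allowed_hosts: FrozenSet[str]) -> bool:
    """Reject suspicious characters first, then check scheme and host."""
    # '@' enables userinfo tricks like https://trusted@evil.example/.
    # Backslashes are normalised to '/' by some browsers.
    if "@" in url or "\\" in url:
        return False
    # ASCII control characters (including CR, LF, TAB and DEL).
    if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in url):
        return False
    parsed = urlparse(url)
    # Assumption: only http(s) URLs on an explicit host allow-list pass.
    return parsed.scheme in ("http", "https") and parsed.hostname in allowed_hosts


ALLOWED = frozenset({"app.example.com"})
```

Running the character checks before `urlparse` matters because the parser may silently strip or reinterpret some of these characters.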
2026-04-11 09:27:29 +07:00
majdyz
1935137c10 fix(frontend): memoize DefaultChatTransport to prevent mid-stream resets
Wraps the DefaultChatTransport instantiation in useMemo([sessionId]) so
the same transport object is reused across renders. Without memoisation,
each streaming chunk (which triggers a re-render) created a new transport
instance, resetting useChat's internal Chat state mid-stream. Matches the
pattern already used in useCopilotStream.ts.
2026-04-11 09:25:37 +07:00
majdyz
ce57601305 fix(frontend): fix TypeScript errors in SubscriptionTierSection and its test
- Dialog controlled set callback: use explicit if-block to avoid
  returning 'false | void' (TS2322)
- Test redirect test: use vi.stubGlobal to replace window.location with
  a plain object (Proxy on jsdom Location breaks private-field access)
2026-04-11 09:24:35 +07:00
majdyz
d81bbdb870 fix(backend): avoid caching Stripe error fallback in _get_stripe_price_amount
Return None on StripeError instead of 0 so the @cached decorator
(which skips caching None) does not persist the error state for 5 min.
Added test to verify the None→0 fallback path in get_subscription_status.
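
The interaction between a None-skipping cache and an error fallback can be sketched as follows. `cached_skip_none` is a hypothetical stand-in for the project's `@cached` decorator, and `get_price_amount` only mimics a transient provider failure; neither is the real implementation.

```python
import time
from functools import wraps


def cached_skip_none(ttl_seconds: float):
    """Hypothetical TTL cache that never stores a None result."""

    def decorator(func):
        store = {}  # args -> (expiry timestamp, value)

        @wraps(func)
        def wrapper(*args):
            hit = store.get(args)
            if hit is not None and hit[0] > time.monotonic():
                return hit[1]
            value = func(*args)
            if value is not None:  # None signals "error, retry on the next call"
                store[args] = (time.monotonic() + ttl_seconds, value)
            return value

        return wrapper

    return decorator


calls = []


@cached_skip_none(ttl_seconds=300)
def get_price_amount(price_id: str):
    calls.append(price_id)
    # Fail on the first attempt only, to mimic a transient provider error.
    return None if len(calls) == 1 else 1999
```

A caller that needs a number can still apply the fallback at the call site, for example `get_price_amount("price_x") or 0`, without the 0 ever entering the cache.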
2026-04-11 09:14:24 +07:00
Zamil Majdy
e79214f3dd refactor(frontend/builder): remove useMemo violations + add incremental tool-call scanning
Per AGENTS.md conventions, useMemo/useCallback should not be used unless
asked to optimise. Remove useMemo from ActionList (nodeMap), MessageList
(visibleMessages filter), and useBuilderChatPanel (transport).

Also add lastScannedToolCallIndexRef to make tool-call detection O(new
messages) matching the action parser's incremental approach.
2026-04-11 09:08:15 +07:00
majdyz
7f6163b180 fix(platform): address final PR review comments on subscription billing
- Replace __legacy__ Dialog import with molecules/Dialog in SubscriptionTierSection
- Update test mock to match new Dialog API (controlled pattern)
- Guard still_has_active_sub against empty new_sub_id in sync_subscription_from_stripe
- Move urlparse import from inside _validate_checkout_redirect_url to module level
2026-04-11 09:07:31 +07:00
majdyz
2057b4597e test(frontend): add Vitest+RTL integration tests for SubscriptionTierSection
Covers: tier card rendering, Current badge, cost display, upgrade/downgrade
flow (with Stripe redirect), confirmation dialog, error handling, ENTERPRISE
user messaging, and success param handling.
2026-04-11 09:00:45 +07:00
majdyz
5bb7027f89 fix(platform): address remaining PR review comments on subscription billing
Backend:
- Cache stripe.Price.retrieve with 5-min TTL via _get_stripe_price_amount
  to avoid 200-600ms Stripe round-trip on every GET /credits/subscription
- Use SubscriptionTier enum .value for FREE/ENTERPRISE in tier_costs dict
  for consistency (instead of hardcoded strings)
- Rename misleading test names: "defaults_to_FREE" → "preserves_current_tier"
  to reflect actual behaviour (unknown price IDs preserve tier, not reset)
- Update subscription_routes_test to mock _get_stripe_price_amount instead
  of stripe.Price.retrieve directly, avoiding cached-result interference

Frontend:
- Handle ?subscription=success return from Stripe Checkout: refetch + toast
- Add downgrade confirmation Dialog before cancelling paid subscription
- Handle ENTERPRISE tier: render dedicated admin-managed plan card, not the
  FREE/PRO/BUSINESS tier cards (which would show no "Current" badge)
- Track pendingTier (via variables) so only the clicked button shows "Updating..."
- Show "Pricing available soon" for paid tiers with cost=0 (unconfigured LD flags)
  instead of misleading "Free"
- Move tierError state into the hook, set via changeTier internally
- Move TIER_ORDER constant to module scope (was magic array inside render body)
- Add aria-current="true" to active tier card for screen reader accessibility
- Add role="alert" to all error paragraph elements
- Improve tier descriptions with concrete capacity values
2026-04-11 08:57:34 +07:00
majdyz
958344562b merge(frontend/builder): resolve conflicts from PR #12726 dev merge
Resolve merge conflicts between builder-chat-panel feature and the
per-model cost breakdown PR (#12726):

- Flow.tsx: keep ErrorBoundary wrapper from our branch
- BuilderChatPanel.tsx + useBuilderChatPanel.ts: keep our latest refactor
- platform_cost_test.py: use Prisma ORM style for export test (theirs)
- useBuilderChatPanel.test.ts + BuilderChatPanel.test.tsx: keep latest tests
2026-04-11 08:54:55 +07:00
majdyz
329a034ebe merge(platform): merge latest dev into feat/subscription-tier-billing 2026-04-11 08:50:35 +07:00
majdyz
6b390d6677 fix(backend/copilot): apply session_msg_ceiling to no-resume compression fallback
The no-resume fallback in _build_query_message used raw msg_count (> 1) to
detect multi-message history and session.messages[:-1] for the compression
slice. After a turn-start drain appends pending messages, msg_count is inflated
and the fallback fires on what should be a fresh first turn, placing the current
user message into the history context and delivering a confusing split prompt to
the model.

Apply session_msg_ceiling to both branches:
- elif condition: effective_count > 1 instead of msg_count > 1
- compression slice: session.messages[:effective_count - 1] instead of [:-1]

With _pre_drain_msg_count=1 on a first turn with drained pending messages,
effective_count=1 so the fallback is correctly skipped and current_message
(which already contains both the original and pending text) is returned as-is.

Adds regression test covering the spurious-fallback scenario.
2026-04-11 08:45:54 +07:00
majdyz
519226406d fix(backend/copilot): hide structural errors only when all node errors are credential-related
When GraphValidationError contains a mix of credential and structural errors
(e.g., missing inputs), the previous `any()` check would surface the credentials
setup card and silently discard the structural errors. The user would fix
credentials, re-run, and only then see the structural error — a confusing
two-step failure.

Change the guard from `any(...)` to `all(...)` so the credentials setup card
is only shown when every error in node_errors is credential-related. Mixed
errors fall through to the plain ErrorResponse path.
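
A small sketch of the difference between the two guards. The `node_errors` shape (node ID to field to message) and the substring matcher are assumptions made for illustration.

```python
CREDENTIAL_MARKERS = ("credential",)  # hypothetical matcher strings


def is_credential_error(message: str) -> bool:
    return any(marker in message.lower() for marker in CREDENTIAL_MARKERS)


def only_credential_errors(node_errors: dict[str, dict[str, str]]) -> bool:
    messages = [msg for fields in node_errors.values() for msg in fields.values()]
    # bool(messages) guards against all([]) being True for an empty mapping.
    return bool(messages) and all(is_credential_error(m) for m in messages)
```

With `any(...)`, a mapping holding one credential error and one missing-input error would still select the setup card; with `all(...)` it does not.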

Adds a regression test for the mixed-error case.
2026-04-11 08:29:52 +07:00
majdyz
1d05b06e43 fix(backend/copilot): prevent pending message duplication in stale-transcript gap
When use_resume=True and the transcript is stale, _build_query_message computes
a gap slice from session.messages[transcript_msg_count:-1].  Pending messages
drained at turn start are appended to session.messages AND concatenated into
current_message, so without the ceiling they appear in both gap_context and
current_message.

Capture _pre_drain_msg_count before drain_pending_messages() and pass it as
session_msg_ceiling to _build_query_message.  The gap slice is now bounded at
the pre-drain count, preventing pending messages from leaking into the gap.
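
The effect of the ceiling on the gap slice can be illustrated with plain lists. `build_gap_context` is a hypothetical helper, and the exact end index used by `_build_query_message` is an assumption based on the slices quoted in these commits.

```python
def build_gap_context(messages, transcript_msg_count, session_msg_ceiling=None):
    """Hypothetical helper: messages newer than the transcript, minus the current one."""
    effective_count = len(messages)
    if session_msg_ceiling is not None:
        effective_count = min(session_msg_ceiling, effective_count)
    # Drop the last counted message (the current user message) from the gap.
    end = max(effective_count - 1, 0)
    return messages[transcript_msg_count:end]
```

Without a ceiling the slice behaves like `[transcript_msg_count:-1]` and picks up drained entries; with the pre-drain count as the ceiling, only genuine history remains.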

Adds two regression tests in query_builder_test.py.
2026-04-11 08:25:14 +07:00
majdyz
cc96ed131b fix(backend/copilot): show all creds as missing in race path; add tool guardrails
- _build_setup_requirements_from_validation_error now passes None to
  build_missing_credentials_from_graph so all credential fields show up
  as missing_credentials in the race scenario (previously matched
  credentials became invalid between the prereq check and the executor call).
- Add test_run_agent_execution_credential_race_returns_setup_card to
  mirror the existing schedule-path race test for the run path.
- Add credential-setup guardrail text to create_agent and edit_agent
  tool descriptions so the LLM doesn't call them for credential setup.
2026-04-11 08:13:12 +07:00
majdyz
c58176365f fix(backend/copilot): use atomic Lua EVAL for pending call-frequency counter
Replace separate INCR + EXPIRE with a single Lua EVAL so the rate-limit
key can never be orphaned without a TTL. If the process died between the
two commands the key would persist indefinitely, permanently locking out
the user after hitting the 30-push limit.
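
One common way to make the increment and the TTL atomic is sketched below. The Lua script, the key name, and the synchronous `eval` call are assumptions; the script actually used in routes.py is not shown in this log.

```python
# Lua runs atomically inside Redis, so the counter can never exist without a TTL.
INCR_WITH_TTL_LUA = """
local count = redis.call('INCR', KEYS[1])
if count == 1 then
    redis.call('EXPIRE', KEYS[1], ARGV[1])
end
return count
"""


def is_rate_limited(redis_client, user_id: str, limit: int = 30, window_seconds: int = 60) -> bool:
    """Return True once the caller exceeds `limit` calls in the current window."""
    key = f"copilot:pending:rate:{user_id}"  # hypothetical key name
    # redis-py signature: eval(script, numkeys, *keys_and_args)
    count = redis_client.eval(INCR_WITH_TTL_LUA, 1, key, window_seconds)
    return int(count) > limit
```

Any object exposing a compatible `eval` method can be passed in, which also makes the helper easy to exercise with a fake in tests.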

Fixes sentry bug report on routes.py:1153.
2026-04-11 08:01:15 +07:00
majdyz
6f9f2c72db test(util/service): use Protocol cast instead of type: ignore
Address CodeRabbit feedback on the new GraphValidationError round-trip
tests: the file already has a ``_SupportsGetReturn`` Protocol pattern
and coding guidelines forbid ``# type: ignore`` suppressors.

Add ``_SupportsHandleCallMethodResponse`` Protocol alongside the
existing one, cast the test client to it, and drop the two
``# type: ignore[attr-defined]`` suppressions added for the new tests.
Pre-existing suppressions on older tests are left untouched (separate
concern, out of scope for this PR).
2026-04-10 23:59:36 +00:00
majdyz
24e1e37ebe fix(util/service): preserve GraphValidationError.node_errors over RPC
``GraphValidationError`` carries a structured ``node_errors`` mapping
in addition to its top-level message, but RPC serialisation only
copied ``exc.args`` — so by the time the exception reached the
copilot client (for the schedule path, which goes via
``get_scheduler_client``), ``node_errors`` was ``{}``.

That broke the credential-race fallback introduced in the earlier
commits in this PR: the helper saw an empty ``node_errors``,
concluded it wasn't a credential error, and fell back to the generic
``ErrorResponse`` — exactly the symptom this PR is trying to fix.

Fix:
- Add an optional ``extras`` dict to ``RemoteCallError`` for
  exception types that carry structured attributes beyond ``args``.
- In ``_handle_internal_http_error``, when the exception is a
  ``GraphValidationError``, pack ``node_errors`` into ``extras``.
- In ``_handle_call_method_response``, when reconstructing a
  ``GraphValidationError``, read ``extras.node_errors`` and pass it
  to the constructor.
- Add two tests for the round-trip (with and without extras), the
  latter to ensure backwards compatibility with old server responses.
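
The round trip described in the list above can be sketched as follows. The classes and the payload layout are simplified stand-ins, not the real `RemoteCallError` wire format.

```python
class GraphValidationError(ValueError):
    """Stand-in for the real exception; only the shape matters here."""

    def __init__(self, message, node_errors=None):
        super().__init__(message)
        self.node_errors = node_errors or {}


def serialize_error(exc: Exception) -> dict:
    payload = {"type": type(exc).__name__, "args": list(exc.args)}
    if isinstance(exc, GraphValidationError):
        # Structured attributes are not part of exc.args, so carry them separately.
        payload["extras"] = {"node_errors": exc.node_errors}
    return payload


def deserialize_error(payload: dict) -> Exception:
    if payload["type"] == "GraphValidationError":
        # Older servers send no "extras"; fall back to an empty mapping.
        extras = payload.get("extras") or {}
        return GraphValidationError(
            payload["args"][0], node_errors=extras.get("node_errors")
        )
    return RuntimeError(*payload["args"])
```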
2026-04-10 23:52:10 +00:00
majdyz
398247fb60 fix(copilot): filter matched credentials & share executor helper
Address review feedback on the credential-race setup card:

- Thread `graph_credentials` through
  `_build_setup_requirements_from_validation_error` so
  `missing_credentials` only lists fields the user hasn't connected
  yet. Previously it was computed with `None`, so the inline card
  showed every connected credential as missing during a race —
  the opposite of the UX fix.
- Promote the credential-error-string matcher to a public helper
  `backend.executor.utils.is_credential_validation_error_message` and
  reuse it from both the dry-run path in
  `_construct_starting_node_execution_input` and from the copilot
  tool, so adding a new credential error string only requires touching
  one file.
- Make the setup-card message action-neutral ("try again" instead of
  "try scheduling again") — the helper is shared by the run and
  schedule paths and the previous wording misled run-path users.
- Extend tests: cover the shared helper, add a test that verifies
  connected credentials are filtered out of `missing_credentials`,
  and assert the message is path-neutral.
2026-04-10 23:44:59 +00:00
majdyz
c469c347d5 refactor(backend/copilot): move credential routing rule into agent guide
Per review feedback: spamming the "do NOT use this tool to set up
credentials" warning across `create_agent` / `edit_agent` / `run_agent`
descriptions duplicates guidance that belongs in one place.

Move the "use run_agent / connect_integration, never the Builder" rule
into the existing **Credentials** key rule in `agent_generation_guide.md`
(loaded on demand by `get_agent_building_guide`), and revert the three
tool descriptions to their previous wording.
2026-04-10 23:04:41 +00:00
majdyz
a7d06854e3 feat(copilot): add per-user call-frequency rate limit to pending endpoint
The token-budget check guards against over-spending but does not prevent
rapid-fire pushes from a client with a large budget.  Add a Redis
INCR + EXPIRE fixed-window counter (30 calls per 60-second window per
user) to cap call frequency independently of token consumption.

Returns HTTP 429 with "Too many pending messages" when exceeded.
Fails open (Redis unavailable → allows request).

Adds test for the new 429 path.

Addresses autogpt-pr-reviewer "Should Fix: per-request rate limit".
2026-04-11 00:42:25 +07:00
majdyz
2bd1bbb7d1 chore(backend): drop stray blank line in platform_cost_test.py
Black would reformat this file (extra blank line between two test
classes), which fails the `lint` CI job for any PR branched from the
current dev.  Tiny drive-by fix to keep this PR's CI green.
2026-04-10 17:42:22 +00:00
majdyz
9bfcdf3f11 test(copilot): add combined-fields test for format_pending_as_user_message
Verify that content + context (url + content) + file_ids all appear in
the formatted output when all fields are present simultaneously.

Addresses autogpt-pr-reviewer 'format_pending_as_user_message never
tested with all fields simultaneously'.
2026-04-11 00:35:27 +07:00
majdyz
18c75beb7a nit(copilot): name pub/sub notify payload constant
Replace magic string "1" in redis.publish() with named constant
_NOTIFY_PAYLOAD for self-documentation.

Addresses autogpt-pr-reviewer nit.
2026-04-11 00:33:49 +07:00
majdyz
9da0dd111f refactor(copilot): extract shared file-ID sanitization helper
Extract `_resolve_workspace_files(user_id, file_ids)` helper from the
duplicated UUID-filter + workspace-DB-lookup logic in both
`stream_chat_post` and `queue_pending_message`.

Both endpoints now call the single helper; callers map the returned
`list[UserWorkspaceFile]` to IDs or file-description strings as before.

Also removes the redundant `if user_id:` guard from `stream_chat_post`'s
file-ID block — `Security(auth.get_user_id)` guarantees a non-empty string.
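
The UUID-filter half of such a helper might look like the sketch below. `keep_valid_uuids` is a hypothetical name, and the workspace database lookup that follows it in the real helper is omitted.

```python
import uuid


def keep_valid_uuids(file_ids: list[str]) -> list[str]:
    """Drop anything that does not parse as a UUID, preserving order."""
    valid = []
    for file_id in file_ids:
        try:
            uuid.UUID(file_id)
        except (ValueError, AttributeError, TypeError):
            continue  # not a UUID string: ignore rather than fail the request
        valid.append(file_id)
    return valid
```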

Addresses autogpt-pr-reviewer "Should Fix: Duplicated file-ID sanitization"
and coderabbitai nit on the if user_id guard.
2026-04-11 00:31:03 +07:00
majdyz
3ef24b3234 refactor(copilot): narrow exception handling and type context field
- Replace broad `except Exception` with `except (json.JSONDecodeError,
  ValidationError, TypeError, ValueError)` in drain_pending_messages so
  unexpected non-data errors propagate instead of being silently swallowed
- Introduce `PendingMessageContext` Pydantic model to replace the raw
  `dict[str, str]` for the context field, making the url/content contract
  explicit and enabling typed attribute access instead of .get() calls
- Update routes.py to construct PendingMessageContext from the validated
  request dict before passing to PendingMessage
- Update tests to use PendingMessageContext directly

Addresses coderabbitai review comments.
2026-04-11 00:27:15 +07:00
majdyz
90b9c2ab46 fix(backend/executor): skip execution tier charge for nested tool calls
execution_usage_cost(0) incorrectly charges 1 credit because 0 % threshold
== 0. charge_node_usage passes 0 to _charge_usage to signal "no tier
increment", but the modulo check fires at 0. Fix: skip execution_usage_cost
entirely when execution_count == 0, preserving the intent that nested tool
executions don't count against execution tiers.
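
The modulo pitfall is easy to reproduce in isolation. The threshold of 100 and both function bodies are assumptions for illustration, not the project's pricing logic.

```python
def execution_usage_cost(execution_count: int, threshold: int = 100, cost: int = 1) -> int:
    """Hypothetical tier charge: bill once every `threshold` executions."""
    return cost if execution_count % threshold == 0 else 0


def tier_charge(execution_count: int) -> int:
    # 0 means "no tier increment" (nested tool call), so skip the modulo check.
    if execution_count == 0:
        return 0
    return execution_usage_cost(execution_count)
```

Because `0 % threshold` is 0 for any positive threshold, the unguarded function treats the sentinel value as a tier boundary.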
2026-04-11 00:13:59 +07:00
majdyz
62f3ed79be style(backend): fix Black formatting in platform_cost_test.py
Black detected double blank lines between class definitions in
platform_cost_test.py (pulled from dev base). Normalise to a single
blank line so the CI merge-commit lint check passes.
2026-04-11 00:12:16 +07:00
majdyz
fd658a07b5 fix(backend/copilot): keep credential setup inline when scheduling
The credential gate in `_check_prerequisites` only fires before the
scheduler/executor call. If credentials are deleted (or otherwise drift)
between the prereq check and the actual call, the scheduler raises a
generic `GraphValidationError` and the user sees a plain error string —
in the worst case the LLM falls back to `create_agent`/`edit_agent`,
which return an `AgentSavedResponse` linking the user to the Builder.

Catch `GraphValidationError` in `_run_agent` and `_schedule_agent`,
detect credential-flavoured node errors, and rebuild the inline
`SetupRequirementsResponse` so the credential setup card always appears
without the user leaving the chat.

Also updates the `run_agent` / `create_agent` / `edit_agent` tool
descriptions to explicitly tell the LLM to use `run_agent` (or
`connect_integration`) for credential setup — never the Builder.
2026-04-10 17:12:05 +00:00
majdyz
d10d14ae74 test(copilot): add coverage for pending-message endpoint and URL test
- Add 11 tests for QueuePendingMessageRequest validation and the
  POST /sessions/{id}/messages/pending endpoint covering:
  - 202 happy path
  - 422 on empty/oversized message, context.url > 2KB, context.content > 32KB, >20 file_ids
  - 404 on unknown session
  - 429 on rate limit exceeded
  - file_ids scoped to caller's workspace
- Fix CodeQL false-positive: replace broad url-in-content assertion
  with exact [Page URL: url] substring check in pending_messages_test
2026-04-11 00:10:20 +07:00
majdyz
5e8345e5ee fix(copilot): fix CodeQL false-positive in pending_messages_test
Replace broad `url in content` assertion with exact `[Page URL: url]`
substring check so CodeQL does not flag it as Incomplete URL Substring
Sanitization.
2026-04-11 00:06:24 +07:00
majdyz
bb071a9c88 refactor(frontend/builder): address review comments in BuilderChatPanel
- Remove useCallback from handleApplyAction (violates AGENTS.md)
- Import TEXTAREA_MAX_LENGTH from PanelInput instead of duplicating constant
- Remove dead @tanstack/react-query mock and associated invalidateQueries test
2026-04-11 00:03:10 +07:00
majdyz
c327d4f2a8 merge(dev): pull latest dev + fix platform_cost_test.py black formatting
Merge origin/dev to pick up recent changes. Also fix an extra blank line
in backend/data/platform_cost_test.py that black (via Python 3.12 CI)
flags as a lint error in the merge commit.
2026-04-11 00:02:52 +07:00
majdyz
54450def6b fix(platform): guard Stripe webhook against empty-secret HMAC bypass
An empty STRIPE_WEBHOOK_SECRET (the default) lets an attacker compute a
valid HMAC-SHA256 signature using the same empty key and forge any
webhook event (customer.subscription.created, etc.), escalating any
user to an arbitrary subscription tier without paying.

Fix: return 503 immediately when stripe_webhook_secret is unset rather
than proceeding to signature verification. Also add run_in_threadpool
to get_stripe_customer_id and remove the duplicate trialing-sub test.
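
The bypass can be demonstrated with the standard library alone. This is a generic HMAC sketch, not Stripe's actual signature scheme (which also signs a timestamp); `handle_webhook` and its status codes are hypothetical.

```python
import hashlib
import hmac


def sign(secret: str, payload: bytes) -> str:
    return hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()


def handle_webhook(secret: str, payload: bytes, signature: str) -> int:
    """Return an HTTP-style status code."""
    if not secret:
        return 503  # unconfigured: refuse before any signature check
    if not hmac.compare_digest(sign(secret, payload), signature):
        return 400
    return 200
```

Without the first guard, a signature computed by anyone with the empty key would compare equal and the forged event would be accepted.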

Merges origin/feat/subscription-tier-billing which had the open-redirect
guard, blocking-IO fix, and idempotency/ENTERPRISE guard.

Test added: test_stripe_webhook_unconfigured_secret_returns_503
2026-04-11 00:00:50 +07:00
majdyz
a7d97dacf3 fix(copilot): address review comments on pending-messages PR
- Use _pre_drain_msg_count for transcript load gate (len > 1 check)
  to avoid spurious transcript load on first turn with pending messages
- Use _pre_drain_msg_count for the Graphiti warm context gate so warm
  context is not skipped when pending messages are drained on the first turn
- Add context.url/content length validators to QueuePendingMessageRequest
  to prevent LLM context-window stuffing (2K url, 32K content caps)
- Rename underscore-prefixed active variables (_pm, _content, _pt)
  to conventional names (pm, content, pt) per Python convention
2026-04-11 00:00:07 +07:00
majdyz
7e7b3c42cb style(backend/executor): replace deprecated get_event_loop().run_in_executor with asyncio.to_thread
asyncio.get_event_loop() has emitted deprecation warnings outside a running
event loop since Python 3.10; asyncio.to_thread is the idiomatic replacement
and consistent with the rest of manager.py.
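
A minimal before/after sketch of the replacement. `blocking_charge` is a hypothetical stand-in for the synchronous work.

```python
import asyncio
import time


def blocking_charge(amount: int) -> int:
    time.sleep(0.01)  # simulate blocking I/O
    return amount * 2


async def charge_async(amount: int) -> int:
    # Before: await asyncio.get_event_loop().run_in_executor(None, blocking_charge, amount)
    return await asyncio.to_thread(blocking_charge, amount)
```

Both forms run the function in the default thread pool; `asyncio.to_thread` (Python 3.9+) additionally propagates context variables and accepts keyword arguments.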
2026-04-10 23:56:24 +07:00
majdyz
8ad5bf03a7 fix(platform): critical security fixes for Stripe webhook + async IO
- Guard stripe_webhook: return 503 when STRIPE_WEBHOOK_SECRET is empty.
  An empty secret allows HMAC forgery (an attacker can compute a valid
  signature with the same empty key), so we reject all webhook calls when
  unconfigured.
- Suppress raw Stripe error from 502 cancel response; log server-side instead.
- Wrap all blocking Stripe SDK calls in run_in_threadpool: Customer.create,
  Subscription.list, Subscription.cancel, checkout.Session.create.
- cancel_stripe_subscription now also cancels 'trialing' subscriptions
  (previously only 'active'), preventing billing after a FREE downgrade.
- session.url None now raises ValueError instead of returning empty string.
- Add tests: webhook 503 on missing secret, trialing-sub cancellation.
2026-04-10 23:55:18 +07:00
majdyz
6523dce30c fix(backend/orchestrator): use AsyncMock for charge_node_usage in test
charge_node_usage is async and is directly awaited; using MagicMock
caused a TypeError that was silently swallowed by the outer except
Exception block, meaning the billing assertion passed for the wrong
reason (mock called but await failed, so no billing actually ran).
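
The false positive can be reproduced in a few lines. `bill` is a hypothetical caller with the same broad `except Exception` shape as described above.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def bill(processor) -> str:
    try:
        await processor.charge_node_usage()
        return "charged"
    except Exception:
        # A TypeError from awaiting a non-awaitable lands here unnoticed.
        return "swallowed"


sync_style = MagicMock()  # charge_node_usage() returns a plain MagicMock
async_style = MagicMock()
async_style.charge_node_usage = AsyncMock()  # returns an awaitable
```

With the `MagicMock`, `assert_called_once()` still passes even though the await failed, which is why the original assertion succeeded for the wrong reason; `AsyncMock` supports `assert_awaited_once()` instead.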
2026-04-10 23:53:52 +07:00
majdyz
39e89b50a7 fix(copilot): address remaining CI failures on pending-messages
1. SDK pyright: the inner ``_fetch_transcript`` closure captured
   ``session`` which pyright couldn't narrow to non-None (the outer
   scope casts it, but the narrowing doesn't propagate into the
   nested async function).  Added an explicit ``assert session is not
   None`` at the top of the closure.
2. Lint: re-formatted ``platform_cost_test.py`` — some pre-existing
   whitespace drift from an upstream merge was tripping Black on CI.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 16:41:55 +00:00
majdyz
f8f7df7b0a fix(copilot): address CI failures on pending-messages PR
1. SDK retry tests failing with "Event loop is closed" — the
   drain-at-start call in stream_chat_completion_sdk was reaching the
   real ``drain_pending_messages`` (which hits Redis) instead of being
   mocked.  Added a ``drain_pending_messages`` stub returning ``[]`` to
   the shared ``_make_sdk_patches`` helper so all retry-integration
   tests skip the drain path.

2. API types check failing — the new
   ``POST /sessions/{id}/messages/pending`` endpoint wasn't reflected
   in the frontend's ``openapi.json``.  Regenerated via
   ``poetry run export-api-schema --output ../frontend/src/app/api/openapi.json``
   and ``pnpm prettier --write``.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 16:34:20 +00:00
majdyz
1d0202a882 Merge branch 'feat/copilot-pending-messages' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-pending-messages 2026-04-10 23:29:57 +07:00
majdyz
a4dbcf4247 fix(backend/copilot): address round-3 review — dedup, persist, guards
- Replace maybe_append_user_message with direct session.messages.append
  for pending drain in both baseline mid-loop and SDK drain-at-start:
  pending messages are atomically popped from Redis and are never
  stale-cache duplicates, so the dedup is wrong and causes
  openai_messages/transcript to diverge from the DB record
- Add immediate upsert_chat_session after SDK drain-at-start so a
  crash between drain and finally doesn't lose messages already removed
  from Redis
- Capture _pre_drain_msg_count before the baseline drain-at-start:
  use it for is_first_turn (prevents pending messages from flipping the
  flag to False on an actual first turn) and for _load_prior_transcript
  (prevents the stale-transcript check from firing on every turn that
  drains pending messages, which would block transcript upload forever)
- Remove redundant if user_id: guards in queue_pending_message — user_id
  is guaranteed non-empty by Security(auth.get_user_id); the guards made
  the rate-limit check silently optional
2026-04-10 23:29:44 +07:00
majdyz
51465fbb02 docs(pending_messages): fix two stale comments in pending_messages.py
Round 4 review nits:

- ``_PUSH_LUA`` block comment mentioned "returns 0 from our earlier
  LLEN" which was a leftover from an earlier design that had a
  separate LLEN check. The atomicity guarantee doesn't depend on it.
  Reworded to describe Redis EVAL serialisation instead.
- ``clear_pending_messages`` docstring said "called at the end of a
  turn" but the finally-block call sites were removed in round 2
  when the atomic drain-at-start became the primary consumer. The
  function is now only an operator/debug escape hatch. Docstring
  updated to match.

No behavioural change.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 16:15:02 +00:00
majdyz
ded048bdfb Merge remote-tracking branch 'origin/dev' into feat/copilot-pending-messages 2026-04-10 23:14:22 +07:00
majdyz
80e580f387 fix(baseline): mirror drained pending messages into transcript_builder
Round 3 follow-up: the drain-at-start in ``stream_chat_completion_baseline``
persisted pending messages to ``session.messages`` but never called
``transcript_builder.append_user`` for them.  A mid-turn transcript
upload would be missing the drained text, which could produce a
malformed assistant-after-assistant structure on the next turn.

The drain block runs BEFORE ``transcript_builder`` is instantiated
(which happens after prompt/transcript async setup), so we can't call
append_user in the drain block itself.  Instead, we remember the
drained list and mirror it into the transcript right after the
single-message ``transcript_builder.append_user(content=message)``
call near the prompt-build site.

Also cleaned up the stray adjacent-string concatenation in the log
line (``"...turn start " "for session %s"`` → single string).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 16:10:34 +00:00
majdyz
f140e73150 fix(copilot): address round 2 review on pending-messages feature
Critical: SDK path was double-injecting.  The endpoint persisted the
message to ``session.messages`` AND the executor drained it from Redis
and concatenated into ``current_message`` — the LLM saw each queued
message twice (once via the compacted history / gap context that
``_build_query_message`` pulls from ``session.messages``, once via
the new query).  Baseline avoided this via ``maybe_append_user_message``
dedup but SDK had no equivalent guard.

### Fix: Redis is the single source of truth

- Endpoint no longer persists to ``session.messages``.  It only
  pushes to Redis and returns.
- Baseline drain-at-start calls ``maybe_append_user_message`` (dedup
  is a safety net, not the primary guard).
- SDK drain-at-start calls ``maybe_append_user_message`` too, so the
  durable transcript records the queued messages.  The concatenation
  into ``current_message`` stays so the SDK CLI sees the content in
  the first user message of the new turn.

### Baseline max-iterations silent-loss — Fixed

``tool_call_loop`` yields ``finished_naturally=False`` when
``iteration == max_iterations`` then returns.  Previously the drain
only skipped ``finished_naturally=True``, so messages drained on the
max-iterations final yield were appended to ``openai_messages`` and
silently lost (the loop was already exiting).  Now the drain also
skips when ``loop_result.iterations >= _MAX_TOOL_ROUNDS``.

### API response cleanup

- ``QueuePendingMessageResponse``: dropped ``queued`` (always True) and
  ``detail`` (human-readable, clients shouldn't parse).  Kept
  ``buffer_length``, ``max_buffer_length``, and ``turn_in_flight``.

### Tests

- Removed dead ``_FakePipeline`` class (the code switched to Lua EVAL
  in round 1 so the pipeline fake was unused).
- Added ``test_drain_decodes_bytes_payloads`` so the ``bytes → str``
  decode branch in ``drain_pending_messages`` is actually exercised
  (real redis-py returns bytes when ``decode_responses=False``).
- Updated ``_FakeRedis.lists`` type hint to ``list[str | bytes]``.
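
A sketch of the decode branch that the new test exercises, assuming each buffered entry is a JSON object. `decode_pending` is a hypothetical helper, not the body of `drain_pending_messages`.

```python
from __future__ import annotations

import json


def decode_pending(raw_items: list[str | bytes]) -> list[dict]:
    """Decode drained entries, skipping payloads that are not valid JSON objects."""
    messages = []
    for raw in raw_items:
        try:
            if isinstance(raw, bytes):
                # redis-py returns bytes when decode_responses=False
                raw = raw.decode("utf-8")
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError, ValueError):
            continue  # malformed payload: drop it and keep draining
        if isinstance(parsed, dict):
            messages.append(parsed)
    return messages
```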

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 15:57:57 +00:00
majdyz
5b27ccf908 refactor(backend/orchestrator): replace charge_per_llm_call flag with extra_credit_charges hook + async charge methods
- Replace ClassVar[bool] charge_per_llm_call in Block._base with
  extra_credit_charges(execution_stats) -> int method; OrchestratorBlock
  overrides it to return max(0, llm_call_count - 1)
- Make charge_extra_iterations and charge_node_usage async; sync work
  runs via run_in_executor. charge_node_usage now folds in
  _handle_low_balance so callers don't touch private methods
- orchestrator.py no longer calls execution_processor._handle_low_balance
  directly; just awaits charge_node_usage which handles it internally
- Update tests: TestChargePerLlmCallFlag -> TestExtraCreditCharges,
  all async charge method tests converted to @pytest.mark.asyncio,
  added TestChargeNodeUsage assertions for _handle_low_balance calls
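
The hook shape described in the first bullet can be sketched as below. `ExecutionStats` is reduced to the one field the example needs; the real classes carry much more.

```python
from dataclasses import dataclass


@dataclass
class ExecutionStats:
    llm_call_count: int = 0


class Block:
    def extra_credit_charges(self, execution_stats: ExecutionStats) -> int:
        """Default: no charges beyond the base execution cost."""
        return 0


class OrchestratorBlock(Block):
    def extra_credit_charges(self, execution_stats: ExecutionStats) -> int:
        # The first LLM call is covered by the base charge; bill the rest.
        return max(0, execution_stats.llm_call_count - 1)
```

A method hook lets each block compute its own surcharge from the stats, where a boolean class flag could only switch one fixed behaviour on or off.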
2026-04-10 22:56:38 +07:00
majdyz
cafe49f295 fix(copilot): address round 1 review on pending-messages feature
Critical fix — the SDK mid-stream injection was structurally broken.
``ClaudeSDKClient.receive_response()`` explicitly returns after the
first ``ResultMessage``, so re-issuing ``client.query()`` and setting
``acc.stream_completed = False`` could never restart the iteration —
the next ``__anext__`` raised ``StopAsyncIteration`` and the injected
turn's response was never consumed.  Replaced the broken mid-stream
path with a turn-start drain that works for both baseline and SDK.

### Changes

**Atomic push via Lua EVAL** (``pending_messages.py``)
- Replace the ``RPUSH`` + ``LTRIM`` + ``EXPIRE`` + ``LLEN`` pipeline
  (which was ``transaction=False`` and racy against concurrent
  ``LPOP``) with a single Lua script so the push is atomic.
- Drop the unused ``enqueued_at`` field.
- Add 16k ``max_length`` cap on ``PendingMessage.content``.

**Baseline path** (``baseline/service.py``)
- Drain at turn start (atomic ``LPOP``): any message queued while the
  session was idle or between turns is picked up before the first
  LLM call.
- Mid-loop drain now skips the final ``tool_call_loop`` yield
  (``finished_naturally=True``); draining there would append a user
  message just as the loop exits, so it would be silently lost.
- Inject via ``format_pending_as_user_message`` so file IDs + context
  are preserved in both ``openai_messages`` and the persisted session
  transcript (previously the DB copy lost file/context metadata).
- Remove the ``finally`` ``clear_pending_messages`` — atomic drain at
  turn start means any late push belongs to the next turn; clearing
  here would racily clobber it.

**SDK path** (``sdk/service.py``)
- Remove the broken mid-stream injection block entirely.
- Drain at turn start (same atomic ``LPOP``) and merge the drained
  messages into ``current_message`` before ``_build_query_message``,
  so the SDK CLI sees them as part of the initial user message.
- Remove the ``finally`` ``clear_pending_messages``.
- Delete the unused ``_combine_pending_messages`` helper.

**Endpoint** (``api/features/chat/routes.py``)
- Enforce ``check_rate_limit`` / ``get_global_rate_limits`` — was
  bypassing per-user daily/weekly token limits that ``/stream``
  enforces.
- ``QueuePendingMessageRequest`` gets ``extra="forbid"`` and
  ``message: max_length=16_000``.
- Push-first, persist-second: if the Redis push fails we raise 5xx;
  previously the session DB got an orphan user message with no
  corresponding queued entry and a retry would duplicate it.
- Log a warning when sanitised file IDs drop unknown entries.
- Persisted message content now uses ``format_pending_as_user_message``
  so the session copy matches what the model actually sees on drain.
- Response returns ``buffer_length``, ``max_buffer_length``, and
  ``turn_in_flight`` so the frontend can show accurate feedback about
  whether the message will hit the current turn or the next one.

**Tests** (``pending_messages_test.py``)
- ``_FakeRedis.eval`` emulates the Lua push script so the existing
  push/drain/cap tests keep working under the new atomic path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 15:37:40 +00:00
majdyz
c6a31cb501 feat(copilot): inject user messages mid-turn via pending buffer
When a user sends a follow-up message while a copilot turn is still
streaming, we now queue it into a per-session Redis buffer and let the
executor currently processing the turn drain it between tool-call
rounds — the model sees the new message before its next LLM call.
Previously such messages were blocked at the RabbitMQ/cluster-lock
layer and only processed after the current turn completed.

### New module
`backend/copilot/pending_messages.py`
- Redis list buffer keyed by ``copilot:pending:{session_id}``
- Pub/sub notify channel as a wake-up hint for future blocking-wait use
- Cap of ``MAX_PENDING_MESSAGES=10`` — trims oldest on overflow
- 1h TTL matches ``stream_ttl`` default
- Helpers: ``push_pending_message``, ``drain_pending_messages``,
  ``peek_pending_count``, ``clear_pending_messages``,
  ``format_pending_as_user_message``
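
A rough in-memory model of the buffer semantics described above (cap, trim-oldest, drain in order). The real module uses a Redis list; ``InMemoryPendingBuffer`` is a hypothetical stand-in that only mirrors the behaviour, and it ignores the TTL and pub/sub parts.

```python
from collections import deque

MAX_PENDING_MESSAGES = 10  # value taken from the commit message


class InMemoryPendingBuffer:
    """In-memory stand-in for the per-session Redis list (illustrative only)."""

    def __init__(self, cap=MAX_PENDING_MESSAGES):
        self._buffers = {}
        self._cap = cap

    def push(self, session_id, message):
        # deque(maxlen=...) discards the oldest item when a push overflows.
        buf = self._buffers.setdefault(session_id, deque(maxlen=self._cap))
        buf.append(message)
        return len(buf)

    def drain(self, session_id):
        # Return everything in arrival order and leave the buffer empty.
        buf = self._buffers.pop(session_id, deque())
        return list(buf)
```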

### New endpoint
`POST /sessions/{session_id}/messages/pending`
- Returns 202 + current buffer length
- Persists the message to the DB so it's in the transcript immediately
- Sanitises file IDs against the caller's workspace
- Does NOT start a new turn (unlike ``stream``)

### Baseline path (simple — in-process injection)
`backend/copilot/baseline/service.py`
- Between iterations of ``tool_call_loop``, drain pending and append to
  the shared ``openai_messages`` list so the loop picks them up on the
  next LLM call
- Persist session via ``upsert_chat_session`` after injection
- Finally-block safety net clears the buffer on early exit

### SDK path (in-process injection via live client.query)
`backend/copilot/sdk/service.py`
- When the SDK loop detects ``acc.stream_completed``, before breaking,
  drain pending and send them via the existing open ``client.query()``
  as a new user message; reset ``stream_completed`` to ``False`` and
  ``continue`` the async-for loop so we keep consuming CLI messages
- Combines multiple drained messages into a single ``query()`` call via
  ``_combine_pending_messages`` to preserve ordering
- Finally-block safety net clears the buffer on early exit
- This works because the Claude Agent SDK's ``ClaudeSDKClient`` is a
  long-lived connection: ``query()`` writes a new user message to the
  CLI's stdin and the same ``receive_response()`` stream picks up the
  next turn's events, so we keep session continuity without releasing
  the cluster lock or restarting the subprocess

### Tests
`backend/copilot/pending_messages_test.py`
- FakeRedis + FakePipeline so tests don't need a live Redis
- Covers push/drain, ordering, buffer cap (MAX_PENDING_MESSAGES),
  clear, publish hook, malformed-payload handling, and the format
  helper (plain / with context / with file_ids)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 15:15:52 +00:00
majdyz
4525869a75 fix(executor): wrap _handle_low_balance in to_thread to avoid blocking
The post-execution low-balance check in on_node_execution called
_handle_low_balance directly. _handle_low_balance does sync DB work,
and on_node_execution is async, so the call blocked the event loop.

Sentry caught it. Wrap in asyncio.to_thread.
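
A minimal sketch of the pattern, assuming a blocking helper. ``handle_low_balance_sync`` and ``on_node_execution_sketch`` are hypothetical stand-ins, not the executor's real functions; the sleep substitutes for sync DB work.

```python
import asyncio
import threading
import time


def handle_low_balance_sync():
    """Hypothetical stand-in for blocking, synchronous DB work."""
    time.sleep(0.01)
    # Report whether we ran on the main thread (where the event loop lives).
    return threading.current_thread() is threading.main_thread()


async def on_node_execution_sketch():
    # to_thread runs the blocking call in a worker thread, so the event
    # loop stays free to service other coroutines while it waits.
    return await asyncio.to_thread(handle_low_balance_sync)
```

``asyncio.to_thread`` requires Python 3.9 or later.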

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 13:34:01 +00:00
majdyz
389b2f4fb2 test(orchestrator): align IBE test with no-error-set fix
The test asserted result_stats.error is set to the IBE, but commit
b662eab36 removed that line from the IBE handler in on_node_execution
because it caused node_error_count++ inconsistency and leaked balance
amounts into persisted node_stats. Update the test to assert
result_stats.error is None (the structured ERROR log is the alerting
hook now, not the .error field).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 13:33:04 +00:00
majdyz
7804e03e7a style(builder-chat): apply prettier formatting to actionApplicators
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 13:20:46 +00:00
majdyz
6469334ae7 fix(executor): notify user when nested tool charge raises IBE
When the OrchestratorBlock's nested tool charge in
_execute_single_tool_with_manager raises InsufficientBalanceError, the
exception propagates up through on_node_execution and the node is
marked FAILED. The main queue's _charge_usage IBE catch (line 1305)
doesn't fire because the initial _charge_usage already succeeded —
only the nested tool charge failed.

This left the user without any notification about why their agent run
stopped. Add a status==FAILED + isinstance(error, IBE) branch in
on_node_execution that calls _handle_insufficient_funds_notif. The
notification flow goes through the same Redis dedup as the main queue
path so repeat runs don't spam.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 13:18:32 +00:00
majdyz
277c196426 fix(builder-chat): guard parsed actions against flowID navigation race
When flowID changes, the flow-reset effect clears messages and
parsedActions but those state updates aren't committed until the next
render. The parsedActions effect could run on a render where the
currentFlowIDRef still holds the previous flowID, briefly re-populating
parsedActions from stale messages and flashing old action buttons in
the new chat panel.

Fix: skip parsing when currentFlowIDRef.current !== flowID, and add
flowID to the effect deps so the effect re-runs once the ref catches up.

Addresses sentry finding 13127725.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 13:15:32 +00:00
majdyz
b29f160849 fix(executor): structured ERROR log on unexpected post-exec billing failures
The catch-all `except Exception` in the post-execution charge path was
logging at WARNING and dropping the error. Sentry flagged this as a
silent billing leak risk.

Now logs at ERROR with the same `billing_leak: True` structured marker
used by the InsufficientBalanceError branch, plus error_type/error/
extra_iterations fields and exc_info=True for the full traceback.
Monitoring/alerting can pick up either failure mode via the same key.
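
Roughly what such a structured log call could look like with the standard ``logging`` module. The field names come from the message above; the logger name, the log text and the ``log_billing_failure`` helper are assumptions made for illustration.

```python
import logging

logger = logging.getLogger("billing_sketch")  # hypothetical logger name


def log_billing_failure(exc, extra_iterations):
    """Log an unexpected charge failure; call from inside an except block."""
    logger.error(
        "post-execution charge failed",
        # Keys in `extra` become attributes on the LogRecord, so alerting
        # can filter on billing_leak regardless of the failure type.
        extra={
            "billing_leak": True,
            "error_type": type(exc).__name__,
            "error": str(exc),
            "extra_iterations": extra_iterations,
        },
        exc_info=True,  # attach the traceback of the active exception
    )
```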

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 13:10:07 +00:00
majdyz
c7885e32ec fix(builder-chat): differential undo for applied graph actions
Undo for chat-applied graph edits now reverts only the specific field
or edge that was changed, rather than restoring a full snapshot of the
nodes/edges arrays. Restoring a whole-array snapshot discarded any
subsequent manual edits the user made after clicking Apply, which was
flagged as a high-severity regression.

- applyUpdateNodeInput: snapshot only the previous value of the single
  field that is about to change. The undo closure re-reads live nodes
  at undo time and only rewrites action.key on the target node. If the
  field did not exist pre-apply, undo deletes it from hardcodedValues.
- applyConnectNodes: drop the pre-clone of edges entirely. The undo
  closure re-reads live edges at undo time and filters out the single
  edge matching source/target/handles, preserving other edges that
  were added afterwards.
- Tests updated to assert differential behavior (later unrelated edits
  are preserved through undo for both nodes and edges).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 13:02:40 +00:00
majdyz
743f1f82c9 style: black formatting on test_orchestrator.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 12:45:41 +00:00
majdyz
fddd23435f fix(orchestrator): address round 3 sentry + coderabbit findings
1. Don't set execution_stats.error on post-hoc IBE
   - Setting it caused node_error_count++ at line 762, creating an
     inconsistent "errored COMPLETED node" in graph metrics
   - Also leaked balance amounts into persisted node_stats
   - Now: log structured ERROR, notify user, leave node_stats clean
   - Fixes sentry finding 3064117116

2. Wire low-balance notifications for nested tool charges
   - charge_node_usage returned (cost, balance) but caller dropped balance
   - Now invokes _handle_low_balance after a successful tool charge,
     mirroring the main queue behaviour
   - Fixes coderabbit finding 3064103939

3. Lint: black formatting on test_orchestrator_per_iteration_cost.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 12:42:27 +00:00
majdyz
41f4064555 fix(frontend/builder): address round-3 review feedback
- Pure render: move parsedActions incremental parsing from useMemo into a
  useEffect (mutating refs inside a memo breaks React Strict Mode). Also
  move currentFlowIDRef sync into an effect so the render body stays pure.
- actionApplicators: extract DEFAULT_EDGE_MARKER_COLOR constant, add a
  shared safeCloneArray<T> helper, and export cloneNodes for unit tests.
- Add dedicated actionApplicators.test.ts (21 tests) covering validation,
  undo snapshot isolation, dangerous-key blocking, idempotent connect,
  and the structuredClone fallback path.
- Add MessageList.test.ts covering normalizePartForRenderer with a runtime
  type guard (isDynamicToolPart) so the unsafe double cast has a real
  regression test.
- Add useBuilderChatPanel tests for sendRawMessage length clamp (empty,
  canSend guard, under-cap passthrough, 4000-char truncation).
- Add BUILDER_CHAT_PANEL default tests to envFlagOverride test suite.
- Memoize visibleMessages filter and nodeMap map construction so they do
  not rebuild on every streaming re-render.
- Accessibility: add focus-visible ring to the toggle button; mark the
  Applied badge with role=status + aria-live=polite for screen readers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 12:32:25 +00:00
majdyz
613321180e fix(orchestrator): propagate InsufficientBalanceError past outer loop catch-all
Follow-up to 6d2d476 addressing sentry's newly-posted review finding.

Both `_execute_tools_agent_mode` and `_execute_tools_sdk_mode` had broad
`except Exception` blocks at the top level of the tool-calling loop that
would still swallow `InsufficientBalanceError` even after the inner tool
executor carve-outs re-raised it: the error would escape the inner
`_execute_single_tool_with_manager`, propagate up through
`_agent_mode_tool_executor`, then get caught by the outer loop's
catch-all and converted into a user-visible "error" yield.

Add explicit `except InsufficientBalanceError: raise` carve-outs before
each broad handler so billing failures propagate all the way out of the
block's `run()` generator, reaching the executor's billing-leak handling
(error recording on execution_stats, user notification, structured log).
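
The carve-out pattern in isolation, as a sketch. ``InsufficientBalanceError`` here is a local stand-in class and ``run_tool_loop`` is a hypothetical simplification of the tool-calling loop.

```python
class InsufficientBalanceError(Exception):
    """Local stand-in for the platform's exception of the same name."""


def run_tool_loop(execute_tool):
    """Return an error string for ordinary failures; let billing errors escape."""
    try:
        return execute_tool()
    except InsufficientBalanceError:
        # Must come before the broad handler: Python uses the first
        # matching except clause, so order decides who sees the error.
        raise
    except Exception as exc:
        return f"error: {exc}"
```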

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 12:00:08 +00:00
majdyz
ada2725628 fix(orchestrator): propagate billing errors and close leak windows
Round 3 review fixes on top of the per-iteration cost charging PR:

- Propagate InsufficientBalanceError out of `_agent_mode_tool_executor`
  and the SDK MCP tool handler so billing failures stop the agent loop
  instead of being re-injected into the LLM as a tool error (which
  previously leaked balance amounts and let the loop keep consuming
  unpaid compute).
- On post-execution extra-iteration charging failure, record
  `execution_stats.error`, log with structured billing_leak fields,
  and fire `_handle_insufficient_funds_notif` so the user is actually
  notified. Comment now matches behaviour.
- Tighten tool-success gate to `tool_node_stats.error is None` so
  cancelled/terminated tool runs (BaseException subclasses such as
  CancelledError) are not billed.
- Extract shared `_resolve_block_cost` helper used by `_charge_usage`
  and `charge_extra_iterations` to DRY the block/cost lookup.
- Add integration tests for the `on_node_execution` charging gate
  covering each branch (status, flag, llm_call_count, dry_run) plus
  the InsufficientBalanceError path that asserts error recording and
  notification.
- Add tool-charging skip tests (dry_run, failed tool, cancelled tool)
  and an InsufficientBalanceError propagation test for
  `_execute_single_tool_with_manager`.
- Assert `charge_node_usage` is actually called in the existing
  `test_orchestrator_agent_mode` test and return a non-zero cost so
  the `merge_stats` branch is exercised.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 11:53:46 +00:00
majdyz
16c38c4dfb style(credit): apply Black formatting
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 09:59:42 +00:00
majdyz
215340690f docs: regenerate block docs (memory_search/memory_store tools)
Auto-generated by `poetry run python scripts/generate_block_docs.py`.
The OrchestratorBlock tool list gained `memory_search` / `memory_store`
on dev but the doc table wasn't regenerated.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 09:59:00 +00:00
majdyz
45d67cfacc style: fix isort import grouping in test_orchestrator_per_iteration_cost
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 09:56:39 +00:00
majdyz
86f6695a33 refactor(frontend/builder): split chat panel into smaller files + address latest review
Addresses the "Should Fix" and "Nice to Have" items from the latest automated review.

- Extracts PanelHeader, MessageList, ActionList, PanelInput, TypingIndicator
  into a local components/ folder so BuilderChatPanel.tsx drops from 452 to
  ~128 lines (under the AGENTS.md 200-line guideline).
- Extracts handleApplyAction branches into applyUpdateNodeInput /
  applyConnectNodes helpers in actionApplicators.ts, and shares a pushUndoEntry
  helper to DRY the MAX_UNDO trimming. useBuilderChatPanel.ts drops from 616
  to ~510 lines.
- Uses structuredClone() for undo snapshots so restore callbacks are isolated
  from in-place mutations (falls back to a shallow copy on unsupported envs).
- Incremental action parsing: lastParsedMessageIndexRef + parsedActionsCacheRef
  avoid the O(all_messages) re-scan per turn.
- Adds a simple LRU cap (MAX_SESSION_CACHE = 50) to graphSessionCache so the
  module-scope Map cannot grow unbounded across navigations.
- sendRawMessage now clamps to MAX_RAW_MESSAGE_LENGTH so programmatic callers
  cannot bypass the textarea length cap.
- TypingIndicator gains role="status"/aria-label for screen readers.
- PanelInput shows a character counter once >=80% of maxLength and highlights
  red at the limit.
- Panel container uses max-h-[70vh] (with min-h-[320px] and sm:max-h-[75vh])
  so it gracefully shrinks on small screens instead of overlapping the
  builder toolbar.
- normalizePartForRenderer extracted from the inline dynamic-tool transform.
- Adds BuilderChatPanel.test.tsx coverage for connect_nodes action label
  rendering (with customized_name + fallback to raw node id).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 09:52:27 +00:00
majdyz
945297b965 fix(backend): cancel trialing Stripe subs alongside active ones
_cancel_customer_subscriptions previously only queried status="active",
leaving trialing subscriptions in place. A user on a trial who downgrades
to FREE, or upgrades to a different paid tier, would continue to be billed
once the trial ended. Query both "active" and "trialing" statuses and
dedupe by sub id to ensure every billable sub is cleaned up.
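
A sketch of the query-both-and-dedupe step. ``list_subscriptions`` is a hypothetical callable standing in for the Stripe listing call, and subscriptions are modelled as plain dicts; neither reflects the actual helper's signature.

```python
def collect_billable_subscriptions(list_subscriptions):
    """Gather active and trialing subscriptions, de-duplicated by id."""
    seen = {}
    for status in ("active", "trialing"):
        for sub in list_subscriptions(status):
            # Keep the first occurrence so no subscription is cancelled twice.
            seen.setdefault(sub["id"], sub)
    return list(seen.values())
```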

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 09:17:38 +00:00
majdyz
17a9ff1278 fix(orchestrator): address review feedback on per-iteration cost charging
Addresses critical issues from autogpt-pr-reviewer + coderabbit review:

Billing safety:
- Surface InsufficientBalanceError as ERROR (not warning) so monitoring
  picks up billing leaks; other charge failures still log a warning.
- Cap extra_iterations at MAX_EXTRA_ITERATIONS=200 to prevent a corrupted
  llm_call_count from draining a user's balance.
- Tools now charged AFTER successful execution, not before — failed
  tools no longer cost credits, matching the rest of the platform.
- charge_node_usage uses execution_count=0 so nested tool calls don't
  inflate the per-execution counter / push users into higher cost tiers.
- charge_extra_iterations now returns (cost, remaining_balance) and the
  caller invokes _handle_low_balance to send low-balance notifications.

Error handling consistency:
- _execute_single_tool_with_manager re-raises InsufficientBalanceError
  instead of swallowing it into a tool-error response. This prevents
  leaking the user's exact balance to the LLM context and lets the
  outer error handling stop the run cleanly, mirroring the main queue.

Test fixes:
- test_orchestrator_per_iteration_cost.py: rewritten with pytest
  monkeypatch fixtures (no more manual save/restore), proper FakeBlock
  with .name attribute set correctly, plus new tests for the cap,
  block-not-found, InsufficientBalanceError propagation, and
  charge_node_usage delegation.
- test_orchestrator.py / test_orchestrator_responses_api.py /
  test_orchestrator_dynamic_fields.py: mock charge_node_usage on the
  execution processor stub so existing agent-mode tests still pass
  with the new charging call.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 08:57:40 +00:00
majdyz
6b57dc0c7f fix(backend): prevent race-condition downgrade in Stripe webhook handler
When Stripe processes a subscription upgrade, the old subscription's
customer.subscription.deleted event may arrive after the new subscription's
customer.subscription.created has already been handled. Unconditionally
setting the user to FREE in the cancel branch would immediately undo the
upgrade.

sync_subscription_from_stripe now checks Stripe for other active/trialing
subscriptions on the same customer before downgrading. If at least one
different active sub exists, the handler preserves the current tier and
returns without writing. Added a regression test that mocks Stripe
returning sub_new as active and asserts set_subscription_tier is never
awaited.
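
The decision described above, reduced to a pure function for illustration. ``should_downgrade_to_free`` and the dict shape are assumptions; the real handler queries Stripe and does more than this.

```python
def should_downgrade_to_free(deleted_sub_id, customer_subscriptions):
    """Decide whether a subscription.deleted event should drop the user to FREE.

    `customer_subscriptions` is a hypothetical list of dicts with "id" and
    "status" keys, standing in for what Stripe reports for the customer.
    """
    for sub in customer_subscriptions:
        if sub["id"] != deleted_sub_id and sub["status"] in ("active", "trialing"):
            # A different billable subscription exists, so this event is the
            # tail end of an upgrade; keep the current tier.
            return False
    return True
```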

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 08:49:23 +00:00
majdyz
c1aec96c0f fix(platform): address round-2 review comments on subscription billing
Security and quality fixes for PR #12727 subscription tier billing review:

- Open-redirect protection: validate success_url/cancel_url against
  settings.config.frontend_base_url before passing to Stripe Checkout.
- Blocking I/O: wrap every synchronous Stripe SDK call (Subscription.list,
  Subscription.cancel, checkout.Session.create) with run_in_threadpool via
  a shared _cancel_customer_subscriptions helper.
- Info leakage: log raw Stripe errors server-side but return a generic
  502 detail to the client ("Please try again or contact support.").
- Webhook idempotency: skip DB writes in sync_subscription_from_stripe
  when the tier is already current, avoiding redundant writes on retry.
- ENTERPRISE guard in webhook: refuse to overwrite ENTERPRISE tier from
  Stripe events (admin-managed, not self-service).
- create_subscription_checkout raises ValueError on empty session.url
  instead of silently returning "".
- Tests: fixture-based client (no leaky try/finally), open-redirect test,
  ENTERPRISE 403 test, webhook dispatch test, trialing status test,
  multi-sub partial-cancel-failure test, idempotency test, renamed
  misleading "defaults to FREE" tests to "preserves_current_tier".
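
One way the open-redirect check could look, as a sketch under assumptions: ``is_allowed_redirect`` is a hypothetical helper, and comparing scheme plus host:port is one reasonable policy, not necessarily the one the PR uses.

```python
from urllib.parse import urlsplit


def is_allowed_redirect(url, frontend_base_url):
    """True only when `url` has the same scheme and host:port as the base URL."""
    target, base = urlsplit(url), urlsplit(frontend_base_url)
    return (
        target.scheme in ("http", "https")
        and target.scheme == base.scheme
        # Compare the parsed authority, not a string prefix, so lookalike
        # hosts and userinfo tricks do not pass.
        and target.netloc.lower() == base.netloc.lower()
    )
```

Comparing parsed components avoids the pitfall of a plain ``startswith`` check, which a URL such as ``https://app.example.com.evil.example/`` would pass.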

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 08:44:01 +00:00
majdyz
f520b64693 fix(backend/orchestrator): charge per LLM iteration and per tool call
The OrchestratorBlock in agent mode makes multiple LLM calls in a single
node execution, but the executor was only charging the user once per run.
Tools spawned by the orchestrator also bypassed _charge_usage entirely,
producing free internal block executions.

Two fixes:

1. Per-iteration LLM charging
   - Add `Block.charge_per_llm_call` class flag (default False)
   - OrchestratorBlock sets it to True
   - After on_node_execution completes, the executor calls
     `charge_extra_iterations(node_exec, llm_call_count - 1)` which
     debits `base_cost * extra_iterations` additional credits via
     spend_credits, using the same per-model cost from BLOCK_COSTS.
   - Skipped for dry runs and failed runs.

2. Tool execution charging
   - In `_execute_single_tool_with_manager`, call the new public
     `ExecutionProcessor.charge_node_usage()` before invoking
     `on_node_execution()` for the spawned tool node.
   - Tools now incur the same credit cost as queue-driven node
     executions; the cost is also added to the orchestrator's
     `extra_cost` so it shows up in graph stats.

Tests cover the flag opt-in and the charge_extra_iterations math
(positive, zero, and negative iterations; zero base cost).
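
The arithmetic behind the extra-iteration charge, as a sketch. ``extra_iteration_cost`` is a hypothetical pure function; treating non-positive inputs as "charge nothing" is an assumption based on the test cases listed above.

```python
def extra_iteration_cost(base_cost, extra_iterations):
    """Credits owed for LLM calls beyond the first one."""
    if base_cost <= 0 or extra_iterations <= 0:
        return 0  # nothing extra to bill
    return base_cost * extra_iterations
```

For example, a block whose per-call cost is 5 credits and which made 4 LLM calls would owe 5 * (4 - 1) = 15 additional credits under this model.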

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 08:21:41 +00:00
majdyz
1921795e42 fix(frontend/builder): wrap BuilderChatPanel in ErrorBoundary
Prevents a runtime error in action parsing or message rendering from
crashing the entire build page. Uses a null fallback so the rest of the
editor remains usable if the chat panel fails.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 07:01:31 +00:00
majdyz
52b0e2a9a6 fix(backend): cancel stale Stripe subs on paid-to-paid tier upgrade
When a PRO user upgrades to BUSINESS via a fresh Checkout Session, Stripe
creates a new subscription without touching the existing one, leaving the
customer double-billed. Cleaning up in sync_subscription_from_stripe
rather than the API handler ensures an abandoned Checkout does not leave
the user without a subscription: we only cancel the old sub once the new
sub has actually become active.

Errors listing or cancelling stale subs are logged but not propagated —
the new subscription tier still gets persisted, and Stripe will retry
the webhook later if listing fails.

Addresses sentry[bot] comment 3061713750 on PR #12727.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 06:54:58 +00:00
majdyz
3ef14e9657 fix(backend): invalidate get_user_tier cache in set_subscription_tier
After a tier change, the rate-limit cache (get_user_tier, 5-minute TTL)
was not cleared, so CoPilot rate limits would continue enforcing the old tier
until the TTL expired. Call get_user_tier.cache_delete(user_id) via a local
import to avoid circular import issues.
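
Why the explicit delete matters, shown with a tiny TTL cache. ``TTLCache`` is a hypothetical stand-in written for this sketch; only the ``cache_delete`` name is borrowed from the message above, and the platform's real caching decorator may work differently.

```python
import time


class TTLCache:
    """Tiny per-key TTL cache with explicit invalidation (illustrative)."""

    def __init__(self, loader, ttl_seconds):
        self._loader = loader
        self._ttl = ttl_seconds
        self._entries = {}  # key -> (expires_at, value)

    def get(self, key):
        entry = self._entries.get(key)
        if entry and entry[0] > time.monotonic():
            return entry[1]  # still fresh: serve the cached value
        value = self._loader(key)
        self._entries[key] = (time.monotonic() + self._ttl, value)
        return value

    def cache_delete(self, key):
        # Drop the entry so the next get() reloads instead of waiting for TTL.
        self._entries.pop(key, None)
```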

Addresses sentry[bot] comment 3061725912 on PR #12727.
2026-04-10 09:43:51 +07:00
majdyz
3c49d3373d fix(backend): remove invalid customer_update parameter from Stripe checkout
customer_update only accepts {address, name, shipping} per Stripe's TypedDict.
The payment_method key does not exist in CreateParamsCustomerUpdate, so pyright
was failing the type-check CI. Remove the invalid parameter — for Stripe
subscriptions the payment method used for the first invoice is automatically
saved to the customer by Stripe.
2026-04-10 09:30:37 +07:00
majdyz
e7e6c8f4b4 refactor(frontend): remove unused legacy subscription methods from BackendAPI
getSubscription() and setSubscriptionTier() in client.ts were replaced by
generated hooks (useGetSubscriptionStatus, useUpdateSubscriptionTier) and
are no longer called anywhere in the codebase. Remove them to avoid adding
further surface area to the deprecated BackendAPI.
2026-04-10 09:25:42 +07:00
majdyz
503e5e1f38 fix(frontend/builder): gate seed message on isOpen to prevent empty-graph context poison
When navigating between graphs with the panel closed, the cached session
could trigger the seed message effect with isOpen=false, causing the nodes
selector to return EMPTY_NODES. This would send an empty-graph summary to
the AI, poisoning the context before the panel was even opened.

Fix: add `isOpen` guard at the top of the seed effect and include it in
the dependency array so the seed fires only when the panel is visible and
nodes reflect the actual graph state.

Add regression test: verifies seed is NOT sent when panel is closed even
when sessionId is cached from a prior navigation and isGraphLoaded is true.
2026-04-10 09:25:16 +07:00
majdyz
4b3e47fe88 fix(platform): propagate Stripe errors in cancel_stripe_subscription
- stripe.Subscription.list() is now wrapped in try-except; StripeError
  is logged and re-raised so callers know the listing failed.
- stripe.Subscription.cancel() StripeError is now re-raised (was swallowed),
  preventing set_subscription_tier from marking the user FREE when Stripe
  cancellation failed.
- update_subscription_tier catches StripeError from cancel and returns HTTP 502
  so DB tier is only updated if Stripe succeeds.
- Fix test patch path: use backend.data.credit.stripe.checkout.Session.create
  instead of bare stripe.checkout.Session.create for import-refactor safety.
- Add tests for raise-on-list-failure, raise-on-cancel-failure, and
  502 route response on cancel failure.

Addresses sentry[bot] comments 3061585490, 3061654688 on PR #12727.
2026-04-10 09:22:44 +07:00
majdyz
6b6ce9db27 test(frontend/builder): add run_agent tool-call detection tests and capture setQueryStates
Adds four tests covering run_agent tool-call handling in useBuilderChatPanel:
- sets flowExecutionID via setQueryStates when execution_id is valid
- skips setQueryStates when output has no execution_id
- rejects path-traversal execution_id (security: /^[\w-]+$/i validation)
- deduplicates run_agent via processedToolCallsRef

Captures mockSetQueryStates from nuqs mock so run_agent assertions can verify
the correct query-state mutation rather than just the absence of errors.
2026-04-10 09:18:14 +07:00
majdyz
cc1cef7da5 fix(platform): set customer default payment method on subscription checkout
Adds customer_update={payment_method: auto} so the payment method used
for the subscription is set as the Stripe customer's default. This makes
it appear pre-selected in future Checkout sessions (manual top-ups).
2026-04-10 09:02:16 +07:00
majdyz
d49f2518a2 fix(frontend/builder): always clear messages on flowID change to keep action state consistent
When navigating back to a cached session, appliedActionKeys was reset to empty
but messages were preserved. This caused previously applied actions to reappear
as unapplied in the UI, allowing them to be re-applied and creating duplicate
undo entries. Clearing messages unconditionally on navigation ensures the
displayed action buttons always reflect the actual applied state.
2026-04-10 02:03:56 +07:00
majdyz
31cb6a2f58 fix(frontend/builder): guard msg.parts with nullish coalescing to prevent runtime error 2026-04-10 01:41:15 +07:00
majdyz
a721fc689b fix(frontend/builder): clear stale messages in retrySession so new session starts clean 2026-04-10 00:56:31 +07:00
majdyz
dc0631809c fix(frontend/builder): reset hasSentSeedMessageRef in retrySession so seed is sent to new session 2026-04-10 00:39:10 +07:00
majdyz
3e9f3d33f9 fix(frontend/builder): wrap panel in CopilotChatActionsProvider to prevent crash
EditAgentTool and RunAgentTool call useCopilotChatActions() which throws
if no provider is in the tree. Wrap the panel content with
CopilotChatActionsProvider wired to sendRawMessage so tool components
can send retry prompts without crashing.
2026-04-09 23:41:06 +07:00
majdyz
5c7a8e1267 fix(frontend/builder): skip Escape-to-close when focus is in textarea/input
Pressing Escape while drafting a message was silently discarding the
user's text. Guard the handler so it only closes the panel when focus is
outside an editable element.
2026-04-09 23:15:56 +07:00
majdyz
969d9cfc41 test(frontend/builder): restore seed-message tests + guard empty messages array
- Re-add describe block for seed message sending (removed in 8b8eb80480):
  - verifies sendMessage is called with buildSeedPrompt when isGraphLoaded=true
  - verifies sendMessage is NOT called when isGraphLoaded=false (default)
  - verifies the hasSentSeedMessageRef guard fires only once per session
- Add test for empty messages guard in prepareSendMessagesRequest
- Guard messages.at(-1) in prepareSendMessagesRequest with an early throw
  so a runtime TypeError cannot occur if the AI SDK contract is violated
2026-04-09 22:15:53 +07:00
majdyz
0b06e948b9 fix(frontend/builder): hide seed message from visible chat messages
Import SEED_PROMPT_PREFIX in BuilderChatPanel and extend the
visibleMessages filter to exclude any user message whose text starts
with the prefix. Adds a regression test for the new filter.
2026-04-09 16:49:18 +07:00
majdyz
a3c97c14c1 fix(frontend/builder): render tool calls via MessagePartRenderer normalization
- Fix visibleMessages filter: assistant messages with only dynamic-tool parts
  (no text) were silently hidden — now included when any dynamic-tool part exists
- Normalize dynamic-tool parts to tool-{toolName} before rendering so
  MessagePartRenderer routes them correctly: edit_agent and run_agent get their
  existing copilot renderers, all other tools fall through to GenericTool
  (collapsed accordion with icon, status text, expandable output)
2026-04-09 13:34:17 +07:00
majdyz
ebec9a11f9 fix(frontend): prevent cross-graph session assignment in concurrent navigation
Track effectFlowID at session creation start and compare against currentFlowIDRef
after the async postV2CreateSession resolves. If the user navigated to a different
graph before the response arrived, the old session ID is discarded instead of
being committed to the new graph's state, preventing chat history from being
crossed between graphs.
2026-04-09 12:06:33 +07:00
majdyz
6673972659 fix(frontend): block prototype-polluting keys without schema + validate execution_id
- Add DANGEROUS_KEYS blocklist (__proto__, constructor, prototype) checked before
  the schema guard in handleApplyAction so schema-less nodes cannot be polluted
  via AI-supplied keys
- Validate execution_id from run_agent tool output with /^[\w-]+$/i before
  passing to setQueryStates, preventing URL-special characters from entering
  query state
- Add tests for DANGEROUS_KEYS blocklist on schema-less nodes (three cases)
2026-04-09 11:48:33 +07:00
majdyz
703b1ff078 test(frontend): add tool-call detection + session ID validation tests; fix EMPTY_NODES ref
- Add tests for edit_agent tool call detection: verifies onGraphEdited fires on
  output-available state, is suppressed during streaming, and is not called twice
  for the same toolCallId (processedToolCallsRef deduplication)
- Add tests for session ID validation: verifies that path-traversal IDs
  (../../admin) and IDs with spaces set sessionError and leave sessionId null
- Extract EMPTY_NODES module-level constant to give useShallow a stable
  reference when the panel is closed, preventing spurious re-renders
2026-04-09 11:43:08 +07:00
majdyz
4fabc6ba16 fix(frontend): pass isGraphLoaded from Flow.tsx + Escape key containment check
- Wire isInitialLoadComplete as isGraphLoaded prop in Flow.tsx so the seed
  message effect in useBuilderChatPanel actually fires once the graph is ready
- Add panelRef to BuilderChatPanel and pass it to the hook so the Escape key
  listener only closes the panel when focus is inside it, preventing conflicts
  with other dialogs or canvas keyboard handlers
- Update BuilderChatPanel test to use objectContaining for the hook call
  assertion, accommodating the new panelRef argument
2026-04-09 11:11:39 +07:00
majdyz
5b798c6511 feat(frontend/builder): use copilot MessagePartRenderer for message rendering
Replace the simplified ReactMarkdown block in BuilderChatPanel's MessageList
with MessagePartRenderer from the copilot panel, enabling proper rendering of
tool invocations, error markers, and system markers in addition to text parts.
2026-04-09 11:04:46 +07:00
majdyz
3be23b20a4 fix(frontend): restore seed message + fix prototype pollution + clear session cache in tests
- Restore isGraphLoaded prop and hasSentSeedMessageRef seed-message effect that
  were removed in a prior external modification; all seed-message tests now pass
- Apply Object.prototype.hasOwnProperty.call() guard in inline handleApplyAction
  for input-schema and handle validation (three sites), matching the extracted
  helper functions; prototype-pollution tests now pass
- Export clearGraphSessionCacheForTesting() and call it in beforeEach to prevent
  stale module-level graphSessionCache from leaking across tests (fixes flowID
  reset test)
- Update BuilderChatPanel test to expect isGraphLoaded in useBuilderChatPanel call
- Remove unused Dispatch, SetStateAction, CustomEdge, CustomNode imports
2026-04-09 11:03:04 +07:00
majdyz
7c789923ec feat(frontend/builder): persistent session per graph, no auto-send, tool detection
- Remove auto-send seed message on chat open (user initiates context manually)
- Cache chat session per graph ID (module-level Map) so reopening the panel for
  the same graph reuses the existing session and preserves conversation history
- Detect edit_agent tool completion → trigger graph refetch via onGraphEdited callback
- Detect run_agent tool completion → update flowExecutionID in URL to auto-follow run
- retrySession now evicts the stale cache entry so a fresh session is created
- Flow.tsx passes refetchGraph as onGraphEdited to BuilderChatPanel
2026-04-09 10:58:53 +07:00
majdyz
3586ad64c1 fix(frontend/builder): address reviewer feedback — prototype pollution, function length, textarea maxLength, and test coverage
- Fix prototype pollution bypass: use Object.prototype.hasOwnProperty.call instead of `in` operator for schema key validation, preventing __proto__/constructor injection through schema-validated nodes
- Extract applyUpdateNodeInput and applyConnectNodes as module-level helpers to reduce handleApplyAction from 165 lines to a 20-line dispatcher
- Add JSDoc to useBuilderChatPanel documenting session lifecycle, transport, seed message, action parsing, undo, and input responsibilities
- Add maxLength=4000 to PanelInput textarea to cap token usage
- Add prototype pollution tests (__proto__ and constructor keys rejected when inputSchema is present)
- Strengthen Send-button-disabled assertion in component test
2026-04-09 10:47:15 +07:00
majdyz
a45fa418a8 feat(frontend/builder): add typing indicator animation to builder chat panel
Shows three bouncing dots in an assistant-style bubble while waiting
for the first response token (status submitted, no assistant text yet).
Disappears once streaming begins and text appears.
2026-04-09 10:37:38 +07:00
majdyz
abcf0830a6 fix(frontend/builder): address reviewer comments on BuilderChatPanel
- Overlapping placeholders: add !seedMessage guard to empty-state block so the
  "Ask me to explain…" and "Graph context sent" banners are mutually exclusive
- aria-modal without focus trap: replace role="dialog"/aria-modal="true" with
  role="complementary" since this is a side panel, not a blocking modal
- Stale closure in handleApplyAction: use useNodeStore/useEdgeStore.getState()
  for both validation and mutation so rapid applies see live data
- Gate nodes/edges Zustand subscriptions behind isOpen to prevent chat-panel
  hook re-running on every node drag/resize when panel is closed
- inputValue not cleared on flowID change: add setInputValue("") to flowID reset
- ReactMarkdown links: add custom <a> component with target="_blank" and rel="noopener noreferrer"
- XML sanitization: apply sanitizeForXml() to n.id and edge handle names
- Regex statefulness: move JSON_BLOCK_REGEX inside parseGraphActions() to avoid
  shared lastIndex state (eliminates fragile lastIndex=0 reset)
- Type guard soundness: add typeof p.text === "string" to extractTextFromParts filter
- Session ID validation: validate format before interpolating into streaming URL
- Shallow-copy undo snapshots: spread prevNodes/prevEdges so closures hold
  independent arrays
- Set spread optimisation: use new Set(prev).add(key) instead of new Set([...prev, key])
- Tests: remove dead getGetV1GetSpecificGraphQueryKey mock, add markerEnd assertion
  to connect_nodes tests, add transport prepareSendMessagesRequest coverage,
  add Enter-with-empty-input and inputValue-reset-on-flowID-change tests
2026-04-09 08:12:35 +07:00
majdyz
3107789867 test(backend): cover usd_to_microdollars(None) and get_platform_cost_logs with explicit start
Closes branch gaps in platform_cost.py (lines 29-31 and 312→314) that
were introduced via the dev merge but not exercised by existing tests.
This also forces the backend CI to run so Codecov uploads fresh coverage
instead of carrying forward stale data from before the cost-tracking
feature landed on dev.
2026-04-09 07:41:16 +07:00
majdyz
dc144b8323 fix(frontend/builder): guard extractTextFromParts against undefined parts
The AI SDK can return messages with undefined parts in certain error
scenarios. Accept null/undefined in extractTextFromParts and fall back
to an empty array to prevent a TypeError and component crash.
2026-04-09 06:55:32 +07:00
majdyz
f440563a67 fix(frontend/builder): clear stale chat messages on graph navigation
Adds a useEffect in useBuilderChatPanel that calls setMessages([]) whenever
the flowID query param changes, preventing old technical seed prompts from
the prior session briefly appearing when switching between agents.
2026-04-09 06:43:58 +07:00
majdyz
308e2469f1 fix(frontend/builder): add markerEnd to chat-applied edges so arrowheads render correctly
Chat panel used setEdges directly without the markerEnd property that edgeStore.addEdge
sets automatically. Added MarkerType.ArrowClosed with strokeWidth=2, color="#555" to
match the standard edge appearance.
2026-04-09 06:29:27 +07:00
majdyz
1ea6ce6ce4 fix(frontend/builder): address review blockers — duplicate edge guard, undo anti-pattern, stack cap, a11y, and test coverage
- Guard against duplicate connect_nodes edges: check prevEdges before applying,
  mark as already-applied without duplicating if edge exists
- Cap undo stack at MAX_UNDO=20 to prevent unbounded memory growth for large graphs
- Fix React anti-pattern: call restore() before setUndoStack updater instead of
  inside it (state updaters must be pure — no side effects)
- Add aria-modal="true" to dialog panel and aria-expanded to toggle button
- Extract IIFE nodeMap into ActionList sub-component (cleaner render path)
- Add 18 new tests: handleSend when canSend=false, Shift+Enter no-send,
  schema-absent permissive paths (update + connect_nodes), sequential multi-undo
  LIFO order, duplicate edge guard, undo stack size cap, empty stack no-op
2026-04-09 06:10:11 +07:00
majdyz
98e5668a6e fix(frontend/builder): prevent appliedActionKeys desync after global undo
Apply chat panel changes via setNodes/setEdges (bypassing history store)
so Ctrl+Z cannot revert them and leave the "Applied" badge stale.
Also hoist jsonBlockRegex to module scope, cap node description length
at 500 chars, and remove useShallow from single-value selectors.
2026-04-09 01:50:24 +07:00
majdyz
1f6981bd06 fix(frontend/builder): fix chat panel undo bypassing global history store
Use setNodes/setEdges directly in undo restore closures instead of
updateNodeData/removeEdge which push to the history store. This prevents
the global Ctrl+Z from re-applying changes that the user already undid via
the chat panel's own undo button.

Also removes unused removeEdge selector from the hook.
2026-04-09 01:36:17 +07:00
majdyz
dd602fefdc fix(frontend/builder): address review comments on builder chat panel
- Replace fragile setTimeout double-toggle retry with dedicated retrySession()
  callback that resets sessionError and lets the session-creation effect re-run
- Remove invalidateQueries after apply actions — caused server refetch to
  overwrite local Zustand state changes (sentry HIGH severity bug)
- Deep-clone prevHardcoded before undo capture so sequential applies to the
  same node each have an independent snapshot
- Remove unsolicited "What does this agent do?" question from seed prompt;
  invite user to initiate instead
- Remove useCallback from handleUndoLastAction per project convention
- Remove unused sendMessage and status from hook return
- Remove JSDoc comment from BuilderChatPanel per project convention
- Hoist nodeMap construction from ActionItem to parent parsedActions.map
  to avoid N identical Maps per render cycle
- Make useChat mock configurable (mockChatMessages/mockChatStatus) and add
  tests for parsedActions integration, Escape key handler, retrySession,
  and handleSend input-clearing behavior
2026-04-09 01:29:41 +07:00
majdyz
6afe84a4c2 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/builder-chat-panel 2026-04-09 01:16:01 +07:00
majdyz
72800e8cb5 fix(frontend/builder): fix XML sanitization, add undo for connect_nodes, add hook tests
- sanitizeForXml now escapes &, ", ' in addition to < and >
- connect_nodes actions now push an undo snapshot (removeEdge) so they can be reverted like update_node_input
- useBuilderChatPanel.test.ts adds removeEdge mock and test for undo of connect_nodes
2026-04-08 23:59:26 +07:00
majdyz
5143652fb0 fix(frontend/builder): add hook tests and fix isCreatingSessionRef leak on navigation
- Restore useBuilderChatPanel.test.ts with 28 tests covering session lifecycle
  (create success, failure, non-200), seed message dispatch + only-once guard,
  flowID reset (sessionId, sessionError, appliedActionKeys), cache invalidation
  assertion after handleApplyAction, and undo stack behaviour
- Fix sentry-flagged bug: reset isCreatingSessionRef.current in the flowID
  change effect so navigating mid-session-creation doesn't permanently block
  future session creation on the new graph
2026-04-08 23:31:45 +07:00
majdyz
a539d4f787 fix(frontend/builder): address PR review — move logic to hook, undo, dedup fix, component tests
- Move inputValue, handleSend, handleKeyDown, isStreaming, canSend into
  useBuilderChatPanel (0ubbe: keep render logic out of component)
- Add undo support: snapshot node state before apply, expose undoStack +
  handleUndoLastAction, show undo button in PanelHeader
- Add toast feedback on handleApplyAction validation failures so users
  know why Apply did nothing
- Fix getActionKey for update_node_input to include value so AI corrections
  in later turns are not silently dropped by the dedup Set
- Add getNodeDisplayName shared helper in helpers.ts; use in both
  serializeGraphForChat and ActionItem (removes duplication)
- Use Map<id, node> in serializeGraphForChat for O(1) edge lookups
- Add Retry button to session error state in MessageList
- Add graph context sent banner after seed message so AI response
  does not appear unprompted (addresses confusing auto-response UX)
- Add aria-label to Apply buttons for screen-reader accessibility
- Remove hook-only test file (0ubbe: test component, not hook)
- Expand component tests: undo, retry, seed banner, action label format,
  getNodeDisplayName, getActionKey value-inclusion, edge truncation
- All 1026 tests pass; lint and types clean
2026-04-08 22:41:34 +07:00
majdyz
9a3236a80a fix(frontend/builder): address PR review — seed filter, validation, tests, session ref guard
- Filter seed message by content prefix (SEED_PROMPT_PREFIX) instead of position
- Add exhaustiveness guard for unhandled GraphAction types
- Guard handleApplyAction against unknown keys/handles via inputSchema/outputSchema
- Add renderHook-based tests: session lifecycle, flowID reset, handleApplyAction, edge cases
- Fix session-creation effect to use isCreatingSessionRef so state-driven re-renders
  don't prematurely cancel the in-flight request via the cancelled flag
- Add empty-input rejection test for BuilderChatPanel send button
2026-04-08 22:07:46 +07:00
majdyz
c1a28d54c2 fix(frontend/builder): require manual action confirmation and prevent prompt injection
- Replace auto-apply with per-action Apply buttons; users must explicitly
  confirm each AI suggestion before the graph is mutated
- Accumulate parsedActions across all assistant messages so multi-turn
  suggestions remain visible rather than disappearing after the next turn
- Escape < and > in node names/descriptions before embedding in XML prompt
  context to prevent AI prompt injection via crafted node labels
- Add MAX_EDGES cap (200) in serializeGraphForChat to mirror the MAX_NODES
  limit and prevent token overruns on dense graphs
- Add Escape key handler in the hook to close the chat panel
- Add helpers.test.ts with unit tests for buildSeedPrompt,
  extractTextFromParts, and XML sanitization
2026-04-08 18:41:58 +07:00
majdyz
b8cb9a506f fix(frontend/builder): hide seed message from chat UI
The initialization prompt ("I'm building an agent in the AutoGPT flow
builder...") was sent as a visible user message, exposing raw prompt
engineering instructions to end users. Track its ID via seedMessageId
and exclude it from the rendered message list.
2026-04-08 16:15:32 +07:00
majdyz
9f10e40c6b fix(frontend/builder): auto-apply AI graph actions after each streaming turn
handleApplyAction was defined and exported but never called, so the
"AI applied these changes" panel was displaying actions that had no
effect. Wire up a handleApplyActionRef so the status-change effect
can safely apply each parsed action to the local Zustand stores once
per completed AI turn, before the canvas refetch resolves.
2026-04-08 15:52:06 +07:00
majdyz
2841e01605 fix(frontend/builder): validate key and handle against node schemas in handleApplyAction
Rejects update_node_input keys not present in inputSchema.properties and
connect_nodes handles not present in outputSchema/inputSchema.properties,
preventing AI from writing arbitrary fields that blocks do not support.
Validation is permissive when schema is undefined (backwards-compatible).
2026-04-08 15:44:12 +07:00
majdyz
7f9486dea5 test(frontend/builder): add hook and component tests for handleApplyAction and session error
- Add useBuilderChatPanel.test.ts with direct tests for handleApplyAction:
  update_node_input (merges hardcodedValues, no-ops for unknown node),
  connect_nodes (calls addEdge with correct args, no-ops if either node missing)
- Add panel open/close state tests for useBuilderChatPanel
- Add session error UI test to BuilderChatPanel.test.tsx
2026-04-08 15:35:30 +07:00
majdyz
fe69c3412b refactor(frontend/builder): extract getActionKey helper, wire textareaRef
- Extract `getActionKey(action)` to helpers.ts, removing duplicated key
  computation from BuilderChatPanel.tsx and useBuilderChatPanel.ts
- Wire `textareaRef` through PanelInputProps so focus-on-open works
- Add `getActionKey` tests covering both action types
2026-04-08 15:08:40 +07:00
majdyz
9107986f5b fix(frontend/builder): escape quotes in welcome state to satisfy react/no-unescaped-entities 2026-04-08 15:00:08 +07:00
majdyz
b2caf6f1b0 fix(frontend/builder): resolve merge conflicts — keep comprehensive security & UX fixes
Merge resolution keeps:
- buildSeedPrompt helper (prompt injection mitigation with XML tags)
- extractTextFromParts naming (aligned with remote)
- cancelled flag pattern for session creation cleanup
- streamError display and empty/welcome state (new in this branch)
- Static Applied badge (span, no dead toggle logic)
- ARIA roles: role=dialog, role=log
- react-markdown for assistant messages
- Placeholder hint for Enter/Shift+Enter
- All new tests: keyboard, multi-action, customized_name, truncation,
  primitive validation, stream error, ARIA assertions
2026-04-08 14:53:35 +07:00
majdyz
dd3225a5c4 fix(frontend/builder): address review comments on chat panel
- Validate node existence before connect_nodes in handleApplyAction
- Add cleanup guard to session creation effect to prevent state updates
  after unmount
- Extract extractTextFromParts helper to deduplicate text extraction
- Remove dead code in ActionItem (applied state was always true)
- Remove redundant setTimeout scroll in handleSend (useEffect handles it)
- Update test to match simplified ActionItem
2026-04-08 07:43:22 +00:00
1011 changed files with 20293 additions and 126348 deletions

View File

@@ -25,8 +25,6 @@ Understand the **Why / What / How** before addressing comments — you need cont
gh pr view {N} --json body --jq '.body'
```
> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks.
## Fetch comments (all sources)
### 1. Inline review threads — GraphQL (primary source of actionable items)
@@ -111,16 +109,12 @@ Only after this loop completes (all pages fetched, count confirmed) should you b
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`).
### 2. Top-level reviews — REST (MUST paginate)
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.**
**CRITICAL: always `--paginate`.** Reviews default to 30 per page. PRs can have 80-170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30, including `autogpt-reviewer`'s structured review, which is typically posted after several CI runs and sits well beyond the first page.
Two things to extract:
@@ -139,8 +133,6 @@ Two things to extract:
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```
> **Already REST — unaffected by GraphQL rate limits.**
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
## For each unaddressed comment
@@ -335,65 +327,18 @@ git push
5. Restart the polling loop from the top — new commits reset CI status.
## GitHub rate limits
## GitHub abuse rate limits
Three distinct rate limits exist — they have different causes, error shapes, and recovery times:
Two distinct rate limits exist — they have different causes and recovery times:
| Error | HTTP code | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` | 403 | Secondary rate limit: too many write operations (comments, mutations) in a short window | Wait **2-3 minutes**. 60s is often not enough. |
| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until `X-RateLimit-Reset` header timestamp |
| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit, distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because each query is charged a point cost. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works; use the fallbacks below. |
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
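The pacing rule above can be sketched in Python. This is a minimal illustration, not part of the skill: `reply_delay_seconds` and `post_with_pacing` are hypothetical names, and `post` stands in for whatever function actually sends one reply.

```python
import time
from typing import Callable, Iterable


def reply_delay_seconds(total_replies: int) -> int:
    """Return the pause between replies: 5s for batches over 20, else 3s."""
    return 5 if total_replies > 20 else 3


def post_with_pacing(
    replies: Iterable[str],
    post: Callable[[str], None],
    sleep: Callable[[float], None] = time.sleep,
) -> None:
    """Send each reply via `post`, pausing between consecutive calls.

    `post` and `sleep` are injected so the pacing can be exercised
    without touching the network or actually waiting.
    """
    batch = list(replies)
    delay = reply_delay_seconds(len(batch))
    for index, reply in enumerate(batch):
        post(reply)
        # No pause is needed after the final reply.
        if index < len(batch) - 1:
            sleep(delay)
```

Injecting `sleep` keeps the helper easy to test; in real use the default `time.sleep` applies.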
### Detection
The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID <id>` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited OR fully down):
```bash
gh api rate_limit --jq '.resources.graphql' # { "limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...}
# Human-readable reset:
gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {}
```
Retry when `remaining > 0`. If you need to proceed sooner, sleep 2-5 min and probe again; the limit is per user, not per machine, so other concurrent agents under the same token also consume it.
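As a rough sketch of that wait decision, assuming the JSON shape printed by the `.resources.graphql` probe above: the function names and the 180-second `probe_interval` default are hypothetical, chosen only for illustration.

```python
import time
from typing import Optional


def seconds_until_graphql_reset(quota: dict, now: Optional[float] = None) -> int:
    """Return how long to wait before GraphQL calls can succeed again.

    `quota` is assumed to look like {"limit": ..., "used": ...,
    "remaining": ..., "reset": <unix seconds>}.
    """
    if quota.get("remaining", 0) > 0:
        return 0  # quota is available, retry immediately
    current = time.time() if now is None else now
    # Never report a negative wait if the reset time has already passed.
    return max(0, int(quota["reset"] - current))


def next_probe_delay(
    quota: dict, now: Optional[float] = None, probe_interval: int = 180
) -> int:
    """Wait until reset, but re-probe at least every `probe_interval` seconds."""
    return min(seconds_until_graphql_reset(quota, now), probe_interval)
```

Re-probing on a short interval rather than sleeping straight through to `reset` accounts for other agents sharing the same token, whose usage changes the picture between probes.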
### What keeps working
When GraphQL is unavailable (rate-limited or outage):
- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe.
- **Degraded:** inline thread list — fall back to flat `/pulls/{N}/comments` REST, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply.
- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset.
### Fall back to REST
**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body' # == --json body
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref' # == --json baseRefName
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable' # == --json mergeable
```
Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again).
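The mapping in the note above is small enough to write down directly. A minimal Python sketch, where `rest_mergeable_to_graphql` is a hypothetical helper name:

```python
from typing import Optional


def rest_mergeable_to_graphql(mergeable: Optional[bool]) -> str:
    """Translate the REST `mergeable` field into the GraphQL enum name.

    REST returns true, false, or null; null means GitHub is still
    computing mergeability, which GraphQL reports as UNKNOWN.
    """
    if mergeable is None:
        return "UNKNOWN"
    return "MERGEABLE" if mergeable else "CONFLICTING"
```

Checking `is None` before the truthiness test matters: `None` and `False` are both falsy, so a plain `if mergeable` would misreport "still computing" as a conflict.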
**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \
| jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]'
```
Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets.
**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed, use the same command as the main flow.
**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance).
### Recovery from secondary rate limit (403 abuse)
**Recovery from secondary rate limit (403):**
1. Stop all API writes immediately
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
3. Resume with `sleep 3` between each call
@@ -452,8 +397,6 @@ gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREA
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow.
### Verify actual count before outputting ORCHESTRATOR:DONE
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:

View File

@@ -1,245 +0,0 @@
---
name: pr-polish
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
user-invocable: true
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Polish
**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:
1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned.
4. Every non-bot, non-author issue comment has been acknowledged (replied-to).
5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending.
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient.
**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2-5 in practice.
## TodoWrite
Before starting, write two todos so the user can see the loop progression:
- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.
Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).
## Find the PR
```bash
ARG_PR="${ARG:-}"
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
ARG_PR="${BASH_REMATCH[1]}"
fi
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
exit 1
fi
echo "Polishing PR #$PR"
```
## The outer loop
```text
round = 0
while round < _MAX_ROUNDS:
    round += 1
    baseline = snapshot_state(PR)      # see "Snapshotting state" below
    invoke_skill("pr-review", PR)      # posts findings as inline comments / top-level review
    findings = diff_state(PR, baseline)
    if findings.total == 0:
        break                          # no new findings → go to polish polling
    invoke_skill("pr-address", PR)     # resolves every unresolved thread + CI failure
# Post-loop: polish polling (see below).
polish_polling(PR)
```
### Snapshotting state
Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):
```bash
# Inline threads — total count + latest databaseId per thread
gh api graphql -f query="
{
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
pullRequest(number: ${PR}) {
reviewThreads(first: 100) {
totalCount
nodes {
id
isResolved
comments(last: 1) { nodes { databaseId } }
}
}
}
}
}" > /tmp/baseline_threads.json
# Top-level reviews — count + latest id per non-empty review
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
--jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
> /tmp/baseline_reviews.json
# Issue comments — count + latest id per non-bot, non-author comment.
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
# accounts like coderabbitai, github-actions, sentry-io). The author is
# filtered by comparing login to the PR author — export it so jq can see it.
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
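# gh's --jq flag accepts only an expression (no --arg), so pipe to jq instead;
# -s merges the per-page arrays that --paginate emits into one list.
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate \
| jq -s --arg author "$AUTHOR" \
'[.[][] | select(.user.type != "Bot" and .user.login != $author)
| {id, user: .user.login, created_at}]' \
> /tmp/baseline_issue_comments.json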
```
### Diffing after a review
After `/pr-review` runs, any of the following counts as a "new finding" and means another address round is needed:
- New inline thread `id` not in the baseline.
- An existing thread whose latest comment `databaseId` is higher than the baseline's (new reply on an old thread).
- A new top-level review `id` with a non-empty body.
- A new issue comment `id` from a non-bot, non-author user.
If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.
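The four buckets above can be sketched in Python. This is an illustration under stated assumptions, not the skill's implementation: the `Snapshot` and `Findings` shapes are hypothetical, and unlike the pseudocode's `diff_state(PR, baseline)` this version compares two snapshots that have already been fetched.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class Snapshot:
    # thread id -> databaseId of the latest comment in that thread
    threads: Dict[str, int] = field(default_factory=dict)
    # ids of top-level reviews with a non-empty body
    review_ids: Set[int] = field(default_factory=set)
    # ids of issue comments from non-bot, non-author users
    issue_comment_ids: Set[int] = field(default_factory=set)


@dataclass
class Findings:
    new_threads: List[str]
    updated_threads: List[str]
    new_reviews: List[int]
    new_issue_comments: List[int]

    @property
    def total(self) -> int:
        return (len(self.new_threads) + len(self.updated_threads)
                + len(self.new_reviews) + len(self.new_issue_comments))


def diff_state(baseline: Snapshot, current: Snapshot) -> Findings:
    """Return everything in `current` that was not present in `baseline`."""
    new_threads = sorted(t for t in current.threads if t not in baseline.threads)
    # An old thread counts as updated when its latest comment id has grown.
    updated_threads = sorted(
        t for t, latest in current.threads.items()
        if t in baseline.threads and latest > baseline.threads[t]
    )
    return Findings(
        new_threads=new_threads,
        updated_threads=updated_threads,
        new_reviews=sorted(current.review_ids - baseline.review_ids),
        new_issue_comments=sorted(
            current.issue_comment_ids - baseline.issue_comment_ids
        ),
    )
```

A `total` of zero corresponds to the "no new findings" exit from the outer loop; anything else means another `/pr-address` round.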
## Polish polling
Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles, 30-90 seconds after the final push. Poll at 60-second intervals:
```text
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
clean_polls = 0
required_clean = 2
while clean_polls < required_clean:
    # 1. CI gate: any terminal non-success conclusion (not just "failure")
    #    must trigger /pr-address. "success", "skipped", "neutral" are clean;
    #    anything else (including cancelled, timed_out, action_required) is a
    #    blocker that won't self-resolve.
    ci = fetch_check_runs(PR)
    if any ci.conclusion in NON_SUCCESS_TERMINAL:
        invoke_skill("pr-address", PR)   # address failures + any new comments
        baseline = snapshot_state(PR)    # reset: a push during address invalidates the old baseline
        clean_polls = 0
        continue
    if any ci.conclusion is None (still in_progress):
        sleep 60; continue               # wait without counting this as clean
    # 2. Comment / thread gate
    threads = fetch_unresolved_threads(PR)
    new_issue_comments = diff_against_baseline(issue_comments)
    new_reviews = diff_against_baseline(reviews)
    if threads or new_issue_comments or new_reviews:
        invoke_skill("pr-address", PR)
        baseline = snapshot_state(PR)    # reset: the address loop just dealt with these,
                                         # otherwise they stay "new" relative to the old baseline forever
        clean_polls = 0
        continue
    # 3. Mergeability gate
    mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
    if mergeable == false (CONFLICTING):
        resolve_conflicts(PR)            # see pr-address skill
        clean_polls = 0
        continue
    if mergeable is null (UNKNOWN):
        sleep 60; continue
    clean_polls += 1
    sleep 60
```
Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
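The CI gate from the loop above can be written as a small classifier. A minimal Python sketch, assuming each check run's `conclusion` arrives as a string or `None` while still in progress; `ci_gate` and its three return labels are hypothetical names.

```python
from typing import Iterable, Optional

NON_SUCCESS_TERMINAL = {
    "failure", "cancelled", "timed_out", "action_required", "startup_failure",
}


def ci_gate(conclusions: Iterable[Optional[str]]) -> str:
    """Classify a set of check-run conclusions for one poll.

    Returns "blocked" if any check ended in a non-success terminal state,
    "pending" if none did but some check is still running, else "clean".
    """
    results = list(conclusions)
    # Blockers take priority: waiting on pending checks cannot fix them.
    if any(c in NON_SUCCESS_TERMINAL for c in results):
        return "blocked"
    if any(c is None for c in results):
        return "pending"
    return "clean"
```

Checking blockers before pending checks mirrors the order in the pseudocode: a cancelled or failed check triggers `/pr-address` right away instead of waiting for the remaining checks to finish.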
### Why 2 clean polls, not 1
A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.
### Why check every source each poll
`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:
- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles).
- Issue comments from human reviewers (not caught by inline thread polling).
- Sentry bug predictions that land on new line numbers post-push.
- Merge conflicts introduced by a race between your push and a merge to `dev`.
## Invocation pattern
Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.
```python
Skill(skill="pr-review", args=pr_url)
Skill(skill="pr-address", args=pr_url)
```
After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.
### **Auto-continue: do NOT end your response between child skills**
`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:
- Do NOT summarize and stop.
- Do NOT wait for user confirmation to continue.
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.
The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).
If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.
## GitHub rate limits
This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:
- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.
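The back-off rules above could be expressed roughly as follows. This is a sketch under stated assumptions: the constant and function names are hypothetical, the threshold of 200 is the approximate figure from the text, and "between any batch of ≥20 writes" is read here as "pause after every 20th write". Only the `gh api rate_limit` command comes from this skill.

```python
import subprocess

GRAPHQL_LOW_WATERMARK = 200  # "~200" in the text; an approximate threshold
WRITE_BATCH_SIZE = 20        # assumed batch size for the write pause
WRITE_BATCH_PAUSE_S = 5      # "sleep 5"


def graphql_remaining() -> int:
    """Ask the gh CLI for the remaining GraphQL budget (command from the text)."""
    out = subprocess.run(
        ["gh", "api", "rate_limit", "--jq", ".resources.graphql.remaining"],
        capture_output=True,
        text=True,
        check=True,
    )
    return int(out.stdout.strip())


def read_transport(remaining: int) -> str:
    """Pick REST for reads once the GraphQL budget runs low."""
    return "rest" if remaining < GRAPHQL_LOW_WATERMARK else "graphql"


def pause_after_write(writes_done: int) -> int:
    """Seconds to sleep after the given number of completed writes."""
    if writes_done > 0 and writes_done % WRITE_BATCH_SIZE == 0:
        return WRITE_BATCH_PAUSE_S
    return 0
```

Thread resolutions stay GraphQL-only either way, so a caller would queue them while `read_transport` returns `"rest"`.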
## Safety valves
- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is a CI `failure` that will never self-resolve.
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved.
## Reporting
When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary:
```
PR #{N} polish complete ({rounds_completed} rounds):
- {X} inline threads opened and resolved
- {Y} CI failures fixed
- {Z} new commits pushed
Final state: CI green, {total} threads all resolved, mergeable.
```
If exiting via `_MAX_ROUNDS`, flag explicitly:
```
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
- {N} threads still unresolved: {titles}
- CI status: {summary}
Needs human review.
```
## When to use this skill
Use when the user says any of:
- "polish this PR"
- "keep reviewing and addressing until it's mergeable"
- "loop /pr-review + /pr-address until done"
- "make sure the PR is actually merge-ready"
Do **not** use when:
- User wants just one review pass (→ `/pr-review`).
- User wants to address already-posted comments without further self-review (→ `/pr-address`).
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.


@@ -5,7 +5,7 @@ user-invocable: true
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
author: autogpt-team
version: "2.1.0"
version: "2.0.0"
---
# Manual E2E Test
@@ -180,120 +180,6 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
## Step 3.0: Claim the testing lock (coordinate parallel agents)
Multiple worktrees share the same host — Docker infra (postgres, redis, clamav), app ports (3000/8006/…), and the test user. Two agents running `/pr-test` concurrently will corrupt each other's state (connection-pool exhaustion, port binds failing silently, cross-test assertions). Use the root-worktree lock file to take turns.
### Lock file contract
Path (**always** the root worktree so all siblings see it): `$REPO_ROOT/.ign.testing.lock`
Body (one `key=value` per line):
```
holder=<pr-XXXXX-purpose>
pid=<pid-or-"self">
started=<iso8601>
heartbeat=<iso8601, updated every ~2 min>
worktree=<full path>
branch=<branch name>
intent=<one-line description + rough duration>
```
### Claim
```bash
LOCK=$REPO_ROOT/.ign.testing.lock
NOW=$(date -u +%Y-%m-%dT%H:%MZ)
STALE_AFTER_MIN=5
if [ -f "$LOCK" ]; then
HB=$(grep '^heartbeat=' "$LOCK" | cut -d= -f2)
# Heartbeats are written in UTC, so parse them as UTC (-u) on BSD/macOS date too
HB_EPOCH=$(date -u -j -f '%Y-%m-%dT%H:%MZ' "$HB" +%s 2>/dev/null || date -d "$HB" +%s 2>/dev/null || echo 0)
AGE_MIN=$(( ( $(date -u +%s) - HB_EPOCH ) / 60 ))
if [ "$AGE_MIN" -gt "$STALE_AFTER_MIN" ]; then
echo "WARN: stale lock (${AGE_MIN}m old) — reclaiming"
cat "$LOCK" | sed 's/^/ stale: /'
else
echo "Another agent holds the lock:"; cat "$LOCK"
echo "Wait until released or resume after $((STALE_AFTER_MIN - AGE_MIN))m."
exit 1
fi
fi
cat > "$LOCK" <<EOF
holder=pr-${PR_NUMBER}-e2e
pid=self
started=$NOW
heartbeat=$NOW
worktree=$WORKTREE_PATH
branch=$(cd "$WORKTREE_PATH" && git branch --show-current)
intent=E2E test PR #${PR_NUMBER}, native mode, ~60min
EOF
echo "Lock claimed"
```
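The stale-lock decision in the claim script can also be read as a small pure function. The following is a sketch, assuming the `key=value` body and the `%Y-%m-%dT%H:%MZ` heartbeat format from the lock file contract; `parse_lock` and `is_stale` are hypothetical helper names, not part of the skill.

```python
from datetime import datetime, timezone

STALE_AFTER_MIN = 5  # same threshold as the claim script


def parse_lock(body: str) -> dict[str, str]:
    """Parse the lock file body: one key=value pair per line."""
    fields: dict[str, str] = {}
    for line in body.splitlines():
        key, sep, value = line.partition("=")
        if sep:  # skip lines without "="
            fields[key.strip()] = value.strip()
    return fields


def is_stale(fields: dict[str, str], now: datetime) -> bool:
    """True when the heartbeat is missing, unparsable, or too old.

    `now` must be timezone-aware. Age is truncated to whole minutes,
    like the integer arithmetic in the shell version.
    """
    heartbeat = fields.get("heartbeat")
    if not heartbeat:
        return True
    try:
        beat = datetime.strptime(heartbeat, "%Y-%m-%dT%H:%MZ")
    except ValueError:
        return True
    beat = beat.replace(tzinfo=timezone.utc)
    age_min = int((now - beat).total_seconds() // 60)
    return age_min > STALE_AFTER_MIN
```

A missing or malformed heartbeat counts as stale here, which matches the shell fallback of `echo 0` for the epoch.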
### Heartbeat (MUST run in background during the whole test)
Without a heartbeat a crashed agent keeps the lock forever. Run this as a background process right after claim:
```bash
(while true; do
sleep 120
[ -f "$LOCK" ] || exit 0 # lock released → exit heartbeat
perl -i -pe "s/^heartbeat=.*/heartbeat=$(date -u +%Y-%m-%dT%H:%MZ)/" "$LOCK"
done) &
HEARTBEAT_PID=$!
echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
```
### Release (always — even on failure)
```bash
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
>> $REPO_ROOT/.ign.testing.log
```
Use a `trap` so release runs even on `exit 1`:
```bash
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
```
### **Release the lock AS SOON AS the test run is done**
The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if:
- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually.
- You're leaving docker containers up.
- You're tailing logs for a minute or two.
Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone.
Concretely, the sequence at the end of every `/pr-test` run (success or failure) is:
```bash
# 1. Write the final report + post PR comment — done above in Step 5/6.
# 2. Release the lock right now, even if the app is still up.
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
>> $REPO_ROOT/.ign.testing.log
# 3. Optionally leave the app running and note it so the user knows:
echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
echo " pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
```
If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test.
### Shared status log
`$REPO_ROOT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
```bash
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
>> $REPO_ROOT/.ign.testing.log
```
## Step 3: Environment setup
### 3a. Copy .env files from the root worktree
@@ -362,87 +248,7 @@ docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocke
done
```
**Native mode also:** when running the app natively (see 3e-native), kill any stray host processes and free the app ports before starting — otherwise `poetry run app` and `pnpm dev` will fail to bind.
```bash
# Kill stray native app processes from prior runs
pkill -9 -f "python.*backend" 2>/dev/null || true
pkill -9 -f "poetry run app" 2>/dev/null || true
pkill -9 -f "next-server|next dev" 2>/dev/null || true
# Free app ports (errors per port are ignored — port may simply be unused)
for port in 3000 8006 8001 8002 8005 8008; do
lsof -ti :$port -sTCP:LISTEN | xargs -r kill -9 2>/dev/null || true
done
```
### 3e-native. Run the app natively (PREFERRED for iterative dev)
Native mode runs infra (postgres, supabase, redis, rabbitmq, clamav) in docker but runs the backend and frontend directly on the host. This avoids the 3-8 minute `docker compose build` cycle on every backend change — code edits are picked up on process restart (seconds) instead of a full image rebuild.
**When to prefer native mode (default for this skill):**
- Iterative dev/debug loops where you're editing backend or frontend code between test runs
- Any PR that touches Python/TS source but not Dockerfiles, compose config, or infra images
- Fast repro of a failing scenario — restart `poetry run app` in a couple of seconds
**When to prefer docker mode (3e fallback):**
- Testing changes to `Dockerfile`, `docker-compose.yml`, or base images
- Production-parity smoke tests (exact container env, networking, volumes)
- CI-equivalent runs where you need the exact image that'll ship
**Note on 3b (copilot auth):** no npm install anywhere. `poetry install` pulls in `claude_agent_sdk`, which ships its own Claude CLI binary — available on `PATH` whenever you run commands via `poetry run` (native) OR whenever the copilot_executor container is built from its Poetry lockfile (docker). The OAuth token extraction still applies (same `refresh_claude_token.sh` call).
**Preamble:** before starting native, run the kill-stray + free-ports block from 3c's "Native mode also" subsection.
**1. Start infra only (one-time per session):**
```bash
cd $PLATFORM_DIR && docker compose --profile local up deps --detach --remove-orphans --build
```
This brings up postgres/supabase/redis/rabbitmq/clamav and skips all app services.
**2. Start the backend natively:**
```bash
cd $BACKEND_DIR && (poetry run app 2>&1 | tee .ign.application.logs) &
```
`poetry run app` spawns **all** app subprocesses — `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, `database_manager` — inside ONE parent process. No separate containers, no separate terminals. The `.ign.application.logs` prefix is already gitignored.
**3. Wait for the backend on :8006 BEFORE starting the frontend.** This ordering matters — the frontend's `pnpm dev` startup invokes `generate-api-queries`, which fetches `/openapi.json` from the backend. If the backend isn't listening yet, `pnpm dev` fails immediately.
```bash
for i in $(seq 1 60); do
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs 2>/dev/null)" = "200" ]; then
echo "Backend ready"
break
fi
sleep 2
done
```
**4. Start the frontend natively:**
```bash
cd $FRONTEND_DIR && (pnpm dev 2>&1 | tee .ign.frontend.logs) &
```
**5. Wait for the frontend on :3000:**
```bash
for i in $(seq 1 60); do
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:3000 2>/dev/null)" = "200" ]; then
echo "Frontend ready"
break
fi
sleep 2
done
```
Once both are up, skip 3e/3f and go straight to **3g/3h** (feature flags / test user creation).
### 3e. Build and start (docker — fallback)
### 3e. Build and start
```bash
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
@@ -636,22 +442,6 @@ agent-browser --session-name pr-test snapshot | grep "text:"
### Checking logs
**Native mode:** when running via `poetry run app` + `pnpm dev`, all app logs stream to the `.ign.*.logs` files written by the `tee` pipes in 3e-native. `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, and `database_manager` are all subprocesses of the single `poetry run app` parent, so their output is interleaved in `.ign.application.logs`.
```bash
# Backend (all app subprocesses interleaved)
tail -f $BACKEND_DIR/.ign.application.logs
# Frontend (Next.js dev server)
tail -f $FRONTEND_DIR/.ign.frontend.logs
# Filter for errors across either log
grep -iE "error|exception|traceback" $BACKEND_DIR/.ign.application.logs | tail -20
grep -iE "error|exception|traceback" $FRONTEND_DIR/.ign.frontend.logs | tail -20
```
**Docker mode:**
```bash
# Backend REST server
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
@@ -781,19 +571,6 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path.
**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`):
```bash
# Reject any local-looking absolute path or home-dir shortcut in the body
if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then
echo "ABORT: local filesystem paths detected in PR comment body."
echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting."
exit 1
fi
```
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
@@ -1099,15 +876,9 @@ test scenario → find issue (bug OR UX problem) → screenshot broken state
### Problem: Frontend shows cookie banner blocking interaction
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
### Problem: Claude CLI not found in copilot_executor container
**Symptom:** Copilot logs say `claude: command not found` or similar when starting an SDK turn.
**Cause:** Image was built without `poetry install` (stale base layer, or Dockerfile bypass). The SDK CLI ships inside the `claude_agent_sdk` Poetry dep — it is NOT an npm package.
**Fix:** Rebuild the image cleanly: `docker compose build --no-cache copilot_executor && docker compose up -d copilot_executor`. Do NOT `docker exec ... npm install -g @anthropic-ai/claude-code` — that is outdated guidance and will pollute the container with a second CLI that the SDK won't use.
### Problem: agent-browser screenshot hangs / times out
**Symptom:** `agent-browser screenshot` exits with code 124 even on `about:blank`.
**Cause:** Stuck CDP connection or Chromium process tree. Seen on macOS when a prior `/pr-test` left a zombie Chrome for Testing.
**Fix:** `pkill -9 -f "agent-browser|chromium|Chrome for Testing" && sleep 2`, then reopen the browser with a fresh `--session-name`. If still failing, verify via `agent-browser eval` + `agent-browser snapshot` (DOM state) instead of relying on PNGs — the feature under test is the same.
### Problem: Container loses npm packages after rebuild
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
### Problem: Services not starting after `docker compose up`
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.


@@ -48,15 +48,14 @@ git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
For each changed file, determine:
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
**Priority order:**
1. Pages with new/changed data fetching or user interactions
2. Components with complex internal logic (modals, forms, wizards)
3. Shared hooks with standalone business logic when UI-level coverage is impractical
3. Hooks with non-trivial business logic
4. Pure helper functions
Skip: styling-only changes, type-only changes, config changes.
@@ -164,7 +163,6 @@ describe("LibraryPage", () => {
- Use `waitFor` when asserting side effects or state changes after interactions
- Import `fireEvent` or `userEvent` from the test-utils for interactions
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
- Keep tests focused: one behavior per test
- Use descriptive test names that read like sentences
@@ -192,7 +190,9 @@ import { http, HttpResponse } from "msw";
server.use(
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
return HttpResponse.json({
agents: [{ id: "1", name: "Test Agent", description: "A test agent" }],
agents: [
{ id: "1", name: "Test Agent", description: "A test agent" },
],
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
});
}),
@@ -211,7 +211,6 @@ pnpm test:unit --reporter=verbose
```
If tests fail:
1. Read the error output carefully
2. Fix the test (not the source code, unless there is a genuine bug)
3. Re-run until all pass


@@ -119,12 +119,10 @@ jobs:
runs-on: ubuntu-latest
services:
# Redis is provisioned as a real 3-shard cluster below via docker
# run (see the "Start Redis Cluster" step). GHA services can't
# override the image CMD or stand up multi-container clusters, so
# that setup is inlined — it mirrors the topology of the local dev
# compose stack (autogpt_platform/docker-compose.platform.yml) and
# prod helm chart.
redis:
image: redis:latest
ports:
- 6379:6379
rabbitmq:
image: rabbitmq:4.1.4
ports:
@@ -168,68 +166,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Start Redis Cluster (3 shards)
run: |
# 3-master Redis Cluster matching the local compose stack
# (autogpt_platform/docker-compose.platform.yml) and prod. Each
# shard runs in its own container on a dedicated bridge network,
# announces its compose-style hostname for intra-network clients,
# and publishes 1700N on the GHA host so tests can reach every
# shard via localhost. The backend's ``_address_remap`` rewrites
# every CLUSTER SLOTS reply to localhost:<announced-port>, which
# picks the right published port per shard.
#
# Not reusing docker-compose.platform.yml directly because compose
# validates the full file even when only some services are ``up``,
# and that file references services (db/kong/...) defined in a
# sibling compose file — pulling both in would needlessly couple
# CI to the full local-dev stack.
docker network create redis-cluster-ci
for i in 0 1 2; do
port=$((17000 + i))
bus=$((27000 + i))
docker run -d --name redis-$i --network redis-cluster-ci \
--network-alias redis-$i \
-p $port:$port \
redis:7 \
redis-server --port $port \
--cluster-enabled yes \
--cluster-config-file nodes.conf \
--cluster-node-timeout 5000 \
--cluster-require-full-coverage no \
--cluster-announce-hostname redis-$i \
--cluster-announce-port $port \
--cluster-announce-bus-port $bus \
--cluster-preferred-endpoint-type hostname
done
# Wait for each shard to accept commands.
for i in 0 1 2; do
port=$((17000 + i))
for _ in $(seq 1 30); do
docker exec redis-$i redis-cli -p $port ping 2>/dev/null | grep -q PONG && break
sleep 1
done
done
# Form the cluster from an init container on the same network so
# --cluster-preferred-endpoint-type hostname resolves redis-0/1/2.
docker run --rm --network redis-cluster-ci redis:7 \
redis-cli --cluster create \
redis-0:17000 redis-1:17001 redis-2:17002 \
--cluster-replicas 0 --cluster-yes
# Confirm convergence.
for _ in $(seq 1 30); do
state=$(docker exec redis-0 redis-cli -p 17000 cluster info | awk -F: '/^cluster_state:/ {print $2}' | tr -d '[:cntrl:]')
if [ "$state" = "ok" ]; then
echo "Redis Cluster ready (3 shards, state=ok)"
docker exec redis-0 redis-cli -p 17000 cluster nodes
exit 0
fi
sleep 1
done
echo "Redis Cluster failed to reach ok state" >&2
docker exec redis-0 redis-cli -p 17000 cluster info >&2 || true
exit 1
- name: Setup Supabase
uses: supabase/setup-cli@v1
with:
@@ -350,13 +286,8 @@ jobs:
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "17000"
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
# Opt-in: lets backend/data/e2e_redis_restart_test.py spin up its
# own isolated 3-shard cluster (ports 27110-27112) and exercise
# ``docker restart <shard>`` mid-stream. Off locally so a
# contributor's ``poetry run test`` doesn't pay the ~15s cost.
E2E_RESTART_ISOLATED: "1"
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}


@@ -160,7 +160,6 @@ jobs:
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -289,14 +288,6 @@ jobs:
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Cache Playwright browsers
uses: actions/cache@v5
with:
path: ~/.cache/ms-playwright
key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
restore-keys: |
playwright-${{ runner.os }}-
- name: Copy source maps from Docker for E2E coverage
run: |
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
@@ -308,8 +299,8 @@ jobs:
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright E2E suite
run: pnpm test:e2e:no-build
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload E2E coverage to Codecov

.gitignore vendored

@@ -187,7 +187,6 @@ autogpt_platform/backend/settings.py
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
/autogpt_platform/backend/poetry.toml
# Test database
test.db
@@ -195,8 +194,3 @@ test.db
# Implementation plans (generated by AI agents)
plans/
.claude/worktrees/
test-results/
# Playwright MCP / local browser-testing artifacts
.playwright-mcp/
copilot-session-switch-qa/


@@ -267,7 +267,7 @@
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
"is_verified": false,
"line_number": 67
"line_number": 55
}
],
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
@@ -467,5 +467,5 @@
}
]
},
"generated_at": "2026-04-24T16:42:44Z"
"generated_at": "2026-04-09T14:20:23Z"
}


@@ -1,6 +1,3 @@
*.ignore.*
*.ign.*
.application.logs
# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
.claude/settings.local.json


@@ -0,0 +1,33 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class RateLimitSettings(BaseSettings):
redis_host: str = Field(
default="redis://localhost:6379",
description="Redis host",
validation_alias="REDIS_HOST",
)
redis_port: str = Field(
default="6379", description="Redis port", validation_alias="REDIS_PORT"
)
redis_password: Optional[str] = Field(
default=None,
description="Redis password",
validation_alias="REDIS_PASSWORD",
)
requests_per_minute: int = Field(
default=60,
description="Maximum number of requests allowed per minute per API key",
validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
)
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
RATE_LIMIT_SETTINGS = RateLimitSettings()


@@ -0,0 +1,51 @@
import time
from typing import Tuple
from redis import Redis
from .config import RATE_LIMIT_SETTINGS
class RateLimiter:
def __init__(
self,
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
):
self.redis = Redis(
host=redis_host,
port=int(redis_port),
password=redis_password,
decode_responses=True,
)
self.window = 60
self.max_requests = requests_per_minute
async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
"""
Check if request is within rate limits.
Args:
api_key_id: The API key identifier to check
Returns:
Tuple of (is_allowed, remaining_requests, reset_time)
"""
now = time.time()
window_start = now - self.window
key = f"ratelimit:{api_key_id}:1min"
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, window_start)
pipe.zadd(key, {str(now): now})
pipe.zcount(key, window_start, now)
pipe.expire(key, self.window)
_, _, request_count, _ = pipe.execute()
remaining = max(0, self.max_requests - request_count)
reset_time = int(now + self.window)
return request_count <= self.max_requests, remaining, reset_time


@@ -0,0 +1,32 @@
from fastapi import HTTPException, Request
from starlette.middleware.base import RequestResponseEndpoint
from .limiter import RateLimiter
async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
"""FastAPI middleware for rate limiting API requests."""
limiter = RateLimiter()
if not request.url.path.startswith("/api"):
return await call_next(request)
api_key = request.headers.get("Authorization")
if not api_key:
return await call_next(request)
api_key = api_key.replace("Bearer ", "")
is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)
if not is_allowed:
raise HTTPException(
status_code=429, detail="Rate limit exceeded. Please try again later."
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Reset"] = str(reset_time)
return response


@@ -59,8 +59,6 @@ class OAuthState(BaseModel):
code_verifier: Optional[str] = None
scopes: list[str]
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
credential_id: Optional[str] = None
"""If set, this OAuth flow upgrades an existing credential's scopes."""
class UserMetadata(BaseModel):


@@ -1,16 +1,13 @@
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
from expiringdict import ExpiringDict
if TYPE_CHECKING:
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redis.asyncio.lock import Lock as AsyncRedisLock
AsyncRedisLike = Union[AsyncRedis, AsyncRedisCluster]
class AsyncRedisKeyedMutex:
"""
@@ -20,7 +17,7 @@ class AsyncRedisKeyedMutex:
in case the key is not unlocked for a specified duration, to prevent memory leaks.
"""
def __init__(self, redis: "AsyncRedisLike", timeout: int | None = 60):
def __init__(self, redis: "AsyncRedis", timeout: int | None = 60):
self.redis = redis
self.timeout = timeout
self.locks: dict[Any, "AsyncRedisLock"] = ExpiringDict(


@@ -37,23 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
# Web Push (VAPID) — generate with: poetry run python -c "
# from py_vapid import Vapid; import base64
# from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
# v = Vapid(); v.generate_keys()
# raw_priv = v.private_key.private_numbers().private_value.to_bytes(32, 'big')
# print('VAPID_PRIVATE_KEY=' + base64.urlsafe_b64encode(raw_priv).rstrip(b'=').decode())
# raw_pub = v.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
# print('VAPID_PUBLIC_KEY=' + base64.urlsafe_b64encode(raw_pub).rstrip(b'=').decode())
# "
# Dev-only keypair below — DO NOT use in staging/production. Regenerate
# your own with the snippet above before any non-local deployment.
VAPID_PRIVATE_KEY=17hBPdSdn6TR_yAgQxA0TjTcvRj3Lf6znHnASZ4rOKc
VAPID_PUBLIC_KEY=BBg49iVTWthVbRYphwmZNvZyiSJDqtSO4nmLxDzLKe3Oo9jbtu0Usa14xX4HQQNLUeiEfzD42zWSlrvY1PR12bs
# Per RFC 8292 push services use this in 410 Gone reports; set to a real
# mailbox in production. Defaults to a placeholder for local dev.
VAPID_CLAIM_EMAIL=mailto:dev@example.com
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
@@ -77,8 +60,7 @@ NVIDIA_API_KEY=
# Graphiti Temporal Knowledge Graph Memory
# Rollout controlled by LaunchDarkly flag "graphiti-memory"
# LLM key falls back to CHAT_API_KEY (AutoPilot), then OPEN_ROUTER_API_KEY.
# Embedder key falls back to CHAT_OPENAI_API_KEY (AutoPilot), then OPENAI_API_KEY.
# LLM/embedder keys fall back to OPEN_ROUTER_API_KEY and OPENAI_API_KEY when empty.
GRAPHITI_FALKORDB_HOST=localhost
GRAPHITI_FALKORDB_PORT=6380
GRAPHITI_FALKORDB_PASSWORD=
@@ -196,13 +178,6 @@ MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=
# Platform Bot Linking
PLATFORM_LINK_BASE_URL=http://localhost:3000/link
# CoPilot chat-platform bridge (Discord/Telegram/Slack)
# Uses FRONTEND_BASE_URL (above) for link confirmation pages.
AUTOPILOT_BOT_DISCORD_TOKEN=
# Communication Services
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=


@@ -1,166 +0,0 @@
{
"id": "858e2226-e047-4d19-a832-3be4a134d155",
"version": 2,
"is_active": true,
"name": "Calculator agent",
"description": "",
"instructions": null,
"recommended_schedule_cron": null,
"forked_from_id": null,
"forked_from_version": null,
"user_id": "",
"created_at": "2026-04-13T03:45:11.241Z",
"nodes": [
{
"id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
"input_default": {
"name": "Input",
"secret": false,
"advanced": false
},
"metadata": {
"position": {
"x": -188.2244873046875,
"y": 95
}
},
"input_links": [],
"output_links": [
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
},
{
"id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"block_id": "363ae599-353e-4804-937e-b2ee3cef3da4",
"input_default": {
"name": "Output",
"secret": false,
"advanced": false,
"escape_html": false
},
"metadata": {
"position": {
"x": 825.198974609375,
"y": 123.75
}
},
"input_links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
}
],
"output_links": [],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
},
{
"id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
"input_default": {
"b": 34,
"operation": "Add",
"round_result": false
},
"metadata": {
"position": {
"x": 323.0255126953125,
"y": 121.25
}
},
"input_links": [
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"output_links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
}
],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
}
],
"links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
},
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"sub_graphs": [],
"input_schema": {
"type": "object",
"properties": {
"Input": {
"advanced": false,
"secret": false,
"title": "Input"
}
},
"required": [
"Input"
]
},
"output_schema": {
"type": "object",
"properties": {
"Output": {
"advanced": false,
"secret": false,
"title": "Output"
}
},
"required": [
"Output"
]
},
"has_external_trigger": false,
"has_human_in_the_loop": false,
"has_sensitive_action": false,
"trigger_setup_info": null,
"credentials_input_schema": {
"type": "object",
"properties": {},
"required": []
}
}
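Review note: in the fixture above, every entry in the top-level `links` array also appears in the `output_links` of its source node and the `input_links` of its sink node. A minimal sketch of a consistency check for that mirroring follows. The helper name `find_unmirrored_links` is hypothetical, and the `nodes` key is an assumption, since the array's opening line falls outside this excerpt.

```python
from typing import Any


def find_unmirrored_links(graph: dict[str, Any]) -> list[str]:
    """Return ids of top-level links that are not mirrored on both endpoint nodes.

    Assumes the node array is stored under "nodes" (not visible in the excerpt).
    """
    nodes = {node["id"]: node for node in graph.get("nodes", [])}
    problems: list[str] = []
    for link in graph.get("links", []):
        source = nodes.get(link["source_id"])
        sink = nodes.get(link["sink_id"])
        # The link must be listed as an output of its source node...
        in_source = source is not None and any(
            out["id"] == link["id"] for out in source.get("output_links", [])
        )
        # ...and as an input of its sink node.
        in_sink = sink is not None and any(
            inp["id"] == link["id"] for inp in sink.get("input_links", [])
        )
        if not (in_source and in_sink):
            problems.append(link["id"])
    return problems
```

Run against a loaded fixture, an empty list would indicate that the per-node link lists and the top-level `links` array agree.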

@@ -1,44 +1,14 @@
import asyncio
import json
import logging
import time
from typing import Awaitable, Callable, Dict, Optional, Set
from typing import Dict, Set
from fastapi import WebSocket, WebSocketDisconnect
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.exceptions import MovedError, RedisError, ResponseError
from starlette.websockets import WebSocketState
from fastapi import WebSocket
from backend.api.model import WSMessage, WSMethod
from backend.data import redis_client as redis
from backend.data.event_bus import _assert_no_wildcard
from backend.api.model import NotificationPayload, WSMessage, WSMethod
from backend.data.execution import (
ExecutionEventType,
exec_channel,
get_graph_execution_meta,
graph_all_channel,
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.data.notification_bus import NotificationEvent
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
def _is_ws_close_race(exc: BaseException, websocket: WebSocket) -> bool:
"""A SPUBLISH→WS send racing with WS close — benign, drop quietly."""
if isinstance(exc, WebSocketDisconnect):
return True
if (
getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED
or getattr(websocket, "client_state", None) == WebSocketState.DISCONNECTED
):
return True
if isinstance(exc, RuntimeError) and "close message has been sent" in str(exc):
return True
return False
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
@@ -46,379 +16,128 @@ _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
}
def event_bus_channel(channel_key: str) -> str:
"""Prefix a channel key with the execution event bus name."""
return f"{_settings.config.execution_event_bus_name}/{channel_key}"
def _notification_bus_channel(user_id: str) -> str:
"""Return the full sharded channel name for a user's notifications."""
return f"{_settings.config.notification_event_bus_name}/{user_id}"
MessageHandler = Callable[[Optional[bytes | str]], Awaitable[None]]
def _is_moved_error(exc: BaseException) -> bool:
"""A MOVED redirect — slot migration mid-stream; pump should reconnect."""
if isinstance(exc, MovedError):
return True
if isinstance(exc, ResponseError) and str(exc).startswith("MOVED "):
return True
return False
# Reconnect tunables for shard-failover during pubsub.listen().
_PUMP_RECONNECT_DEADLINE_S = 60.0
_PUMP_RECONNECT_BACKOFF_INITIAL_S = 0.5
_PUMP_RECONNECT_BACKOFF_MAX_S = 8.0
class _Subscription:
"""One SSUBSCRIBE lifecycle bound to a WebSocket, pinned to the owning shard."""
def __init__(self, full_channel: str) -> None:
_assert_no_wildcard(full_channel)
self.full_channel = full_channel
self._client: AsyncRedis | None = None
self._pubsub: AsyncPubSub | None = None
self._task: asyncio.Task | None = None
async def start(self, on_message: MessageHandler) -> None:
await self._open_pubsub()
self._task = asyncio.create_task(self._pump(on_message))
async def _open_pubsub(self) -> None:
"""(Re)establish the sharded pubsub connection + SSUBSCRIBE."""
self._client = await redis.connect_sharded_pubsub_async(self.full_channel)
self._pubsub = self._client.pubsub()
await self._pubsub.execute_command("SSUBSCRIBE", self.full_channel)
# redis-py 6.x async PubSub.listen() exits when ``channels`` is
# empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
self._pubsub.channels[self.full_channel] = None # type: ignore[index]
async def _close_pubsub_quietly(self) -> None:
"""Best-effort teardown before reconnect — never raises."""
if self._pubsub is not None:
try:
await self._pubsub.aclose()
except Exception:
pass
self._pubsub = None
if self._client is not None:
try:
await self._client.aclose()
except Exception:
pass
self._client = None
async def _pump(self, on_message: MessageHandler) -> None:
if self._pubsub is None:
return
backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
while True:
pubsub = self._pubsub
if pubsub is None:
return
needs_reconnect = False
try:
async for message in pubsub.listen():
msg_type = message.get("type")
# Server-pushed sunsubscribe: slot ownership changed and
# Redis revoked our SSUBSCRIBE without dropping the TCP.
# Treat as a reconnect trigger so we re-resolve the shard.
if msg_type == "sunsubscribe":
needs_reconnect = True
break
if msg_type not in ("smessage", "message", "pmessage"):
continue
# Successful read resets the reconnect budget.
backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
try:
await on_message(message.get("data"))
except Exception:
logger.exception(
"Websocket message-handler failed for channel %s",
self.full_channel,
)
if not needs_reconnect:
# listen() exited cleanly (channels emptied) — pump is done.
return
except asyncio.CancelledError:
raise
except (ConnectionError, RedisError) as exc:
if isinstance(exc, ResponseError) and not _is_moved_error(exc):
logger.exception(
"Pubsub pump crashed on non-retryable ResponseError for %s",
self.full_channel,
)
return
if time.monotonic() > deadline:
logger.exception(
"Pubsub pump giving up after reconnect deadline for %s",
self.full_channel,
)
return
logger.warning(
"Pubsub pump reconnecting for %s after %s: %s",
self.full_channel,
type(exc).__name__,
exc,
)
except Exception:
logger.exception("Pubsub pump crashed for %s", self.full_channel)
return
# Either a retryable error was raised, or the server pushed a
# sunsubscribe — close the stale pubsub and reopen against the
# (possibly migrated) shard.
await self._close_pubsub_quietly()
await asyncio.sleep(backoff)
backoff = min(backoff * 2, _PUMP_RECONNECT_BACKOFF_MAX_S)
try:
await self._open_pubsub()
except (ConnectionError, RedisError) as reopen_exc:
logger.warning(
"Pubsub pump reopen failed for %s: %s",
self.full_channel,
reopen_exc,
)
# Loop again — deadline check will eventually exit.
continue
async def stop(self) -> None:
if self._task is not None:
self._task.cancel()
try:
await self._task
except (asyncio.CancelledError, Exception):
pass
self._task = None
if self._pubsub is not None:
try:
await self._pubsub.execute_command("SUNSUBSCRIBE", self.full_channel)
except Exception:
logger.warning(
"SUNSUBSCRIBE failed for %s", self.full_channel, exc_info=True
)
try:
await self._pubsub.aclose()
except Exception:
pass
self._pubsub = None
if self._client is not None:
try:
await self._client.aclose()
except Exception:
pass
self._client = None
class ConnectionManager:
def __init__(self):
self.active_connections: Set[WebSocket] = set()
# channel_key → sockets subscribed (public channel keys, not raw Redis channels)
self.subscriptions: Dict[str, Set[WebSocket]] = {}
# websocket → {channel_key: _Subscription}
self._ws_subs: Dict[WebSocket, Dict[str, _Subscription]] = {}
# websocket → notification subscription
self._ws_notifications: Dict[WebSocket, _Subscription] = {}
self.user_connections: Dict[str, Set[WebSocket]] = {}
async def connect_socket(self, websocket: WebSocket, *, user_id: str):
await websocket.accept()
self.active_connections.add(websocket)
self._ws_subs.setdefault(websocket, {})
await self._start_notification_subscription(websocket, user_id=user_id)
if user_id not in self.user_connections:
self.user_connections[user_id] = set()
self.user_connections[user_id].add(websocket)
async def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
self.active_connections.discard(websocket)
# Stop SSUBSCRIBE pumps before dropping bookkeeping to avoid leaks.
subs = self._ws_subs.pop(websocket, {})
for sub in subs.values():
await sub.stop()
notif_sub = self._ws_notifications.pop(websocket, None)
if notif_sub is not None:
await notif_sub.stop()
for channel_key, subscribers in list(self.subscriptions.items()):
for subscribers in self.subscriptions.values():
subscribers.discard(websocket)
if not subscribers:
self.subscriptions.pop(channel_key, None)
user_conns = self.user_connections.get(user_id)
if user_conns is not None:
user_conns.discard(websocket)
if not user_conns:
self.user_connections.pop(user_id, None)
async def subscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str:
# Hash-tagged channel needs graph_id; resolve once per subscribe.
meta = await get_graph_execution_meta(user_id, graph_exec_id)
if meta is None:
raise ValueError(
f"graph_exec #{graph_exec_id} not found for user #{user_id}"
)
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
full_channel = event_bus_channel(
exec_channel(user_id, meta.graph_id, graph_exec_id)
return await self._subscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
)
await self._open_subscription(websocket, channel_key, full_channel)
return channel_key
async def subscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str:
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
full_channel = event_bus_channel(graph_all_channel(user_id, graph_id))
await self._open_subscription(websocket, channel_key, full_channel)
return channel_key
return await self._subscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
async def unsubscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str | None:
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
return await self._close_subscription(websocket, channel_key)
return await self._unsubscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
)
async def unsubscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str | None:
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
return await self._close_subscription(websocket, channel_key)
return await self._unsubscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
async def _open_subscription(
self, websocket: WebSocket, channel_key: str, full_channel: str
) -> None:
self.subscriptions.setdefault(channel_key, set()).add(websocket)
per_ws = self._ws_subs.setdefault(websocket, {})
if channel_key in per_ws:
return
sub = _Subscription(full_channel)
async def send_execution_update(
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
) -> int:
graph_exec_id = (
exec_event.id
if isinstance(exec_event, GraphExecutionEvent)
else exec_event.graph_exec_id
)
async def on_message(data: Optional[bytes | str]) -> None:
await self._forward_exec_event(websocket, channel_key, data)
n_sent = 0
await sub.start(on_message)
per_ws[channel_key] = sub
async def _close_subscription(
self, websocket: WebSocket, channel_key: str
) -> str | None:
subscribers = self.subscriptions.get(channel_key)
if subscribers is None:
return None
subscribers.discard(websocket)
if not subscribers:
self.subscriptions.pop(channel_key, None)
per_ws = self._ws_subs.get(websocket)
if per_ws and channel_key in per_ws:
sub = per_ws.pop(channel_key)
await sub.stop()
return channel_key
async def _forward_exec_event(
self,
websocket: WebSocket,
channel_key: str,
raw_payload: Optional[bytes | str],
) -> None:
if raw_payload is None:
return
# Unwrap the `_EventPayloadWrapper` envelope, then re-wrap as a WS message.
try:
wrapper = (
raw_payload.decode()
if isinstance(raw_payload, (bytes, bytearray))
else raw_payload
channels: set[str] = {
# Send update to listeners for this graph execution
_graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
}
if isinstance(exec_event, GraphExecutionEvent):
# Send update to listeners for all executions of this graph
channels.add(
_graph_execs_channel_key(
exec_event.user_id, graph_id=exec_event.graph_id
)
)
except Exception:
logger.warning(
"Failed to decode pubsub payload on %s", channel_key, exc_info=True
)
return
try:
parsed = json.loads(wrapper)
event_data = parsed.get("payload")
if not isinstance(event_data, dict):
return
event_type = event_data.get("event_type")
method = _EVENT_TYPE_TO_METHOD_MAP.get(ExecutionEventType(event_type))
if method is None:
return
for channel in channels.intersection(self.subscriptions.keys()):
message = WSMessage(
method=method,
channel=channel_key,
data=event_data,
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
channel=channel,
data=exec_event.model_dump(),
).model_dump_json()
await websocket.send_text(message)
except Exception as e:
if _is_ws_close_race(e, websocket):
logger.debug("Dropped exec event on closed WS for %s", channel_key)
return
logger.exception("Failed to forward exec event on %s", channel_key)
for connection in self.subscriptions[channel]:
await connection.send_text(message)
n_sent += 1
async def _start_notification_subscription(
self, websocket: WebSocket, *, user_id: str
) -> None:
full_channel = _notification_bus_channel(user_id)
sub = _Subscription(full_channel)
return n_sent
async def on_message(data: Optional[bytes | str]) -> None:
await self._forward_notification(websocket, user_id, data)
try:
await sub.start(on_message)
except Exception:
logger.exception(
"Failed to open notification SSUBSCRIBE for user=%s", user_id
)
return
self._ws_notifications[websocket] = sub
async def _forward_notification(
self,
websocket: WebSocket,
user_id: str,
raw_payload: Optional[bytes | str],
) -> None:
if raw_payload is None:
return
try:
wrapper_json = (
raw_payload.decode()
if isinstance(raw_payload, (bytes, bytearray))
else raw_payload
)
parsed = json.loads(wrapper_json)
inner = parsed.get("payload") if isinstance(parsed, dict) else None
if not isinstance(inner, dict):
return
event = NotificationEvent.model_validate(inner)
except Exception:
logger.warning(
"Failed to parse notification payload for user=%s",
user_id,
exc_info=True,
)
return
# Defense in depth against cross-user payloads.
if event.user_id != user_id:
return
async def send_notification(
self, *, user_id: str, payload: NotificationPayload
) -> int:
"""Send a notification to all websocket connections belonging to a user."""
message = WSMessage(
method=WSMethod.NOTIFICATION,
data=event.payload.model_dump(),
data=payload.model_dump(),
).model_dump_json()
try:
await websocket.send_text(message)
except Exception as e:
if _is_ws_close_race(e, websocket):
logger.debug("Dropped notification on closed WS for user=%s", user_id)
return
logger.warning(
"Failed to deliver notification to WS for user=%s",
user_id,
exc_info=True,
)
connections = tuple(self.user_connections.get(user_id, set()))
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()
self.subscriptions[channel_key].add(websocket)
return channel_key
async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
if channel_key in self.subscriptions:
self.subscriptions[channel_key].discard(websocket)
if not self.subscriptions[channel_key]:
del self.subscriptions[channel_key]
return channel_key
return None
def graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
return f"{user_id}|graph_exec#{graph_exec_id}"
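Review note: the `_Subscription._pump` loop in this file pairs a doubling backoff (0.5 s initial, 8 s cap) with a 60 s reconnect deadline. The sketch below isolates that pattern so the delay sequence is easy to reason about. It is a simplified illustration, not the code in the diff: `backoff_delays` and `reconnect_until_deadline` are hypothetical names, only `ConnectionError` is retried, and the budget is not reset after a successful read.

```python
import asyncio
import time
from typing import Awaitable, Callable, Iterator


def backoff_delays(initial: float = 0.5, maximum: float = 8.0) -> Iterator[float]:
    """Yield delays that double on each step and are capped at `maximum`."""
    delay = initial
    while True:
        yield delay
        delay = min(delay * 2, maximum)


async def reconnect_until_deadline(
    reopen: Callable[[], Awaitable[None]],
    *,
    deadline_s: float = 60.0,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> bool:
    """Call `reopen` with backoff until it succeeds or the deadline passes."""
    deadline = time.monotonic() + deadline_s
    for delay in backoff_delays():
        if time.monotonic() > deadline:
            return False  # budget exhausted, give up
        await sleep(delay)
        try:
            await reopen()
        except ConnectionError:
            continue  # retry with the next, longer delay
        return True
    return False  # unreachable: backoff_delays() never ends
```

Passing `sleep` as a parameter keeps the loop testable without real waiting, which is one way a unit test could cover the give-up path.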

@@ -1,386 +0,0 @@
"""ConnectionManager integration over the live 3-shard Redis cluster:
SSUBSCRIBE → SPUBLISH → WebSocket forwarding with no Redis mocks. Skips
when the cluster is unreachable."""
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from fastapi import WebSocket
import backend.data.redis_client as redis_client
from backend.api.conn_manager import (
ConnectionManager,
_graph_execs_channel_key,
event_bus_channel,
graph_exec_channel_key,
)
from backend.api.model import WSMethod
from backend.data.execution import (
ExecutionStatus,
GraphExecutionEvent,
GraphExecutionMeta,
NodeExecutionEvent,
exec_channel,
graph_all_channel,
)
def _has_live_cluster() -> bool:
try:
c = redis_client.connect()
except Exception: # noqa: BLE001 — any connect failure → skip
return False
try:
c.close()
except Exception:
pass
return True
pytestmark = pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip conn_manager integration",
)
def _meta(user_id: str, graph_id: str, graph_exec_id: str) -> GraphExecutionMeta:
"""Build a minimal GraphExecutionMeta for ``subscribe_graph_exec`` to use."""
return GraphExecutionMeta(
id=graph_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=ExecutionStatus.RUNNING,
started_at=datetime.now(tz=timezone.utc),
ended_at=None,
stats=GraphExecutionMeta.Stats(),
)
def _node_event_payload(
*, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
"""Wire-format a NodeExecutionEvent the way RedisExecutionEventBus would."""
inner = NodeExecutionEvent(
user_id=user_id,
graph_id=graph_id,
graph_version=1,
graph_exec_id=graph_exec_id,
node_exec_id=f"node-exec-{marker}",
node_id="node-1",
block_id="block-1",
status=ExecutionStatus.COMPLETED,
input_data={"in": marker},
output_data={"out": [marker]},
add_time=datetime.now(tz=timezone.utc),
queue_time=None,
start_time=datetime.now(tz=timezone.utc),
end_time=datetime.now(tz=timezone.utc),
).model_dump(mode="json")
return json.dumps({"payload": inner}).encode()
def _graph_event_payload(
*, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
inner = GraphExecutionEvent(
id=graph_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=1,
preset_id=None,
status=ExecutionStatus.COMPLETED,
started_at=datetime.now(tz=timezone.utc),
ended_at=datetime.now(tz=timezone.utc),
stats=GraphExecutionEvent.Stats(
cost=0,
duration=1.0,
node_exec_time=0.5,
node_exec_count=1,
),
inputs={"x": marker},
credential_inputs=None,
nodes_input_masks=None,
outputs={"y": [marker]},
).model_dump(mode="json")
return json.dumps({"payload": inner}).encode()
async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
"""Poll ``predicate()`` until truthy or timeout — used to wait for pubsub."""
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if predicate():
return True
await asyncio.sleep(interval)
return False
@pytest.mark.asyncio
async def test_two_clients_get_independent_ssubscribes_on_right_shards(
monkeypatch,
) -> None:
"""Two WS clients on different graph_exec_ids each receive ONLY their
own publish, even when the channels land on different shards."""
user_id = "user-conn-int-1"
graph_a = f"graph-a-{uuid4().hex[:8]}"
graph_b = f"graph-b-{uuid4().hex[:8]}"
exec_a = f"exec-a-{uuid4().hex[:8]}"
exec_b = f"exec-b-{uuid4().hex[:8]}"
# Stub Prisma lookup so tests don't need a DB.
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_a if gex_id == exec_a else graph_b, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws_a: AsyncMock = AsyncMock(spec=WebSocket)
ws_b: AsyncMock = AsyncMock(spec=WebSocket)
sent_a: list[str] = []
sent_b: list[str] = []
ws_a.send_text = AsyncMock(side_effect=lambda m: sent_a.append(m))
ws_b.send_text = AsyncMock(side_effect=lambda m: sent_b.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
try:
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_a, websocket=ws_a
)
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_b, websocket=ws_b
)
# Let SSUBSCRIBE settle on each shard.
await asyncio.sleep(0.2)
# Publish to each per-exec channel.
chan_a = event_bus_channel(exec_channel(user_id, graph_a, exec_a))
chan_b = event_bus_channel(exec_channel(user_id, graph_b, exec_b))
cluster.spublish(
chan_a,
_node_event_payload(
user_id=user_id,
graph_id=graph_a,
graph_exec_id=exec_a,
marker="A",
).decode(),
)
cluster.spublish(
chan_b,
_node_event_payload(
user_id=user_id,
graph_id=graph_b,
graph_exec_id=exec_b,
marker="B",
).decode(),
)
delivered = await _wait_until(lambda: sent_a and sent_b, timeout=5.0)
assert delivered, f"timeout: sent_a={sent_a!r} sent_b={sent_b!r}"
msg_a = json.loads(sent_a[0])
msg_b = json.loads(sent_b[0])
assert msg_a["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_a)
assert msg_b["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_b)
assert msg_a["data"]["graph_exec_id"] == exec_a
assert msg_b["data"]["graph_exec_id"] == exec_b
# No cross-talk: each socket got exactly one message.
assert len(sent_a) == 1 and len(sent_b) == 1
finally:
await cm.disconnect_socket(ws_a, user_id=user_id)
await cm.disconnect_socket(ws_b, user_id=user_id)
redis_client.disconnect()
@pytest.mark.asyncio
async def test_aggregate_channel_receives_per_exec_publishes(monkeypatch) -> None:
"""A subscriber on the ``graph_execs`` aggregate channel must receive the
GraphExecutionEvent published to the ``/all`` channel — even though
per-exec events go to a different channel."""
user_id = "user-conn-int-2"
graph_id = f"graph-{uuid4().hex[:8]}"
exec_id = f"exec-{uuid4().hex[:8]}"
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_id, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws_agg: AsyncMock = AsyncMock(spec=WebSocket)
ws_per: AsyncMock = AsyncMock(spec=WebSocket)
sent_agg: list[str] = []
sent_per: list[str] = []
ws_agg.send_text = AsyncMock(side_effect=lambda m: sent_agg.append(m))
ws_per.send_text = AsyncMock(side_effect=lambda m: sent_per.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
try:
await cm.subscribe_graph_execs(
user_id=user_id, graph_id=graph_id, websocket=ws_agg
)
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_id, websocket=ws_per
)
await asyncio.sleep(0.2)
# The eventbus publishes the same event to both channels — replicate.
chan_per = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
chan_all = event_bus_channel(graph_all_channel(user_id, graph_id))
payload = _graph_event_payload(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=exec_id,
marker="agg",
).decode()
cluster.spublish(chan_per, payload)
cluster.spublish(chan_all, payload)
delivered = await _wait_until(lambda: sent_agg and sent_per, timeout=5.0)
assert delivered, f"sent_agg={sent_agg!r} sent_per={sent_per!r}"
agg_msg = json.loads(sent_agg[0])
per_msg = json.loads(sent_per[0])
# Aggregate subscriber's channel key is the per-graph executions key.
assert agg_msg["channel"] == _graph_execs_channel_key(
user_id, graph_id=graph_id
)
assert per_msg["channel"] == graph_exec_channel_key(
user_id, graph_exec_id=exec_id
)
assert agg_msg["method"] == WSMethod.GRAPH_EXECUTION_EVENT.value
finally:
await cm.disconnect_socket(ws_agg, user_id=user_id)
await cm.disconnect_socket(ws_per, user_id=user_id)
redis_client.disconnect()
@pytest.mark.asyncio
async def test_disconnect_unsubscribes_and_drops_future_publishes(monkeypatch) -> None:
"""After ``disconnect_socket`` runs, a subsequent SPUBLISH must NOT reach
the dead websocket — exercises the SUNSUBSCRIBE + bookkeeping cleanup."""
user_id = "user-conn-int-3"
graph_id = f"graph-{uuid4().hex[:8]}"
exec_id = f"exec-{uuid4().hex[:8]}"
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_id, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws: AsyncMock = AsyncMock(spec=WebSocket)
sent: list[str] = []
ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
payload = _node_event_payload(
user_id=user_id, graph_id=graph_id, graph_exec_id=exec_id, marker="live"
).decode()
try:
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_id, websocket=ws
)
await asyncio.sleep(0.15)
# First publish — must reach the socket.
cluster.spublish(chan, payload)
delivered = await _wait_until(lambda: bool(sent), timeout=5.0)
assert delivered
assert len(sent) == 1
# Disconnect → SUNSUBSCRIBE + bookkeeping cleared.
await cm.disconnect_socket(ws, user_id=user_id)
# Pump cancellation may drain in-flight messages; wait for it.
await asyncio.sleep(0.2)
# Channel bookkeeping must be gone.
assert (
graph_exec_channel_key(user_id, graph_exec_id=exec_id)
not in cm.subscriptions
)
assert ws not in cm._ws_subs
# Second publish — must NOT reach the (already-disconnected) socket.
cluster.spublish(
chan,
_node_event_payload(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=exec_id,
marker="post-disconnect",
).decode(),
)
await asyncio.sleep(0.5)
# Still only the one pre-disconnect message.
assert len(sent) == 1
finally:
redis_client.disconnect()
@pytest.mark.asyncio
async def test_slow_consumer_receives_all_events_without_loss(monkeypatch) -> None:
"""Burst-publish many SPUBLISHes; assert every one reaches the subscriber
in order — guards against drops/reorderings in the pubsub pump."""
user_id = "user-conn-int-4"
graph_id = f"graph-{uuid4().hex[:8]}"
exec_id = f"exec-{uuid4().hex[:8]}"
n_events = 100
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_id, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws: AsyncMock = AsyncMock(spec=WebSocket)
sent: list[str] = []
ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
try:
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_id, websocket=ws
)
await asyncio.sleep(0.2)
# Burst-publish n_events without yielding to the pump.
for i in range(n_events):
cluster.spublish(
chan,
_node_event_payload(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=exec_id,
marker=f"m{i}",
).decode(),
)
delivered = await _wait_until(
lambda: len(sent) >= n_events, timeout=15.0, interval=0.1
)
assert delivered, f"only delivered {len(sent)}/{n_events}"
# Validate ordering — Redis pub/sub is FIFO per channel.
markers = [json.loads(m)["data"]["input_data"]["in"] for m in sent[:n_events]]
assert markers == [f"m{i}" for i in range(n_events)]
finally:
await cm.disconnect_socket(ws, user_id=user_id)
redis_client.disconnect()
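Review note: the tests in this file use a polling helper, `_wait_until`, to wait for pubsub delivery instead of sleeping for a fixed time. Below is a minimal standalone sketch of the same idea. It is a variant rather than the file's code: the name `wait_until` is hypothetical, and it reads the clock from `asyncio.get_running_loop()`, which is the usual choice inside a coroutine.

```python
import asyncio
from typing import Callable


async def wait_until(
    predicate: Callable[[], object],
    timeout: float = 5.0,
    interval: float = 0.05,
) -> bool:
    """Poll `predicate()` until it is truthy or `timeout` seconds elapse."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        if predicate():
            return True
        # Yield to the event loop so background tasks can make progress.
        await asyncio.sleep(interval)
    return False
```

Because the helper returns a boolean instead of raising, a caller can attach its own failure message, as the tests do with `assert delivered, f"..."`.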

File diff suppressed because it is too large

@@ -1,932 +0,0 @@
import asyncio
import logging
from typing import List
from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.models import User as AuthUser
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import AgentExecutionStatus
from pydantic import BaseModel
from backend.api.features.admin.model import (
AgentDiagnosticsResponse,
ExecutionDiagnosticsResponse,
)
from backend.data.diagnostics import (
FailedExecutionDetail,
OrphanedScheduleDetail,
RunningExecutionDetail,
ScheduleDetail,
ScheduleHealthMetrics,
cleanup_all_stuck_queued_executions,
cleanup_orphaned_executions_bulk,
cleanup_orphaned_schedules_bulk,
get_agent_diagnostics,
get_all_orphaned_execution_ids,
get_all_schedules_details,
get_all_stuck_queued_execution_ids,
get_execution_diagnostics,
get_failed_executions_count,
get_failed_executions_details,
get_invalid_executions_details,
get_long_running_executions_details,
get_orphaned_executions_details,
get_orphaned_schedules_details,
get_running_executions_details,
get_schedule_health_metrics,
get_stuck_queued_executions_details,
stop_all_long_running_executions,
)
from backend.data.execution import get_graph_executions
from backend.executor.utils import add_graph_execution, stop_graph_execution
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["diagnostics", "admin"],
dependencies=[Security(requires_admin_user)],
)
class RunningExecutionsListResponse(BaseModel):
"""Response model for list of running executions"""
executions: List[RunningExecutionDetail]
total: int
class FailedExecutionsListResponse(BaseModel):
"""Response model for list of failed executions"""
executions: List[FailedExecutionDetail]
total: int
class StopExecutionRequest(BaseModel):
"""Request model for stopping a single execution"""
execution_id: str
class StopExecutionsRequest(BaseModel):
"""Request model for stopping multiple executions"""
execution_ids: List[str]
class StopExecutionResponse(BaseModel):
"""Response model for stop execution operations"""
success: bool
stopped_count: int = 0
message: str
class RequeueExecutionResponse(BaseModel):
"""Response model for requeue execution operations"""
success: bool
requeued_count: int = 0
message: str
@router.get(
"/diagnostics/executions",
response_model=ExecutionDiagnosticsResponse,
summary="Get Execution Diagnostics",
)
async def get_execution_diagnostics_endpoint():
"""
Get comprehensive diagnostic information about execution status.
Returns all execution metrics including:
- Current state (running, queued)
- Orphaned executions (>24h old, likely not in executor)
- Failure metrics (1h, 24h, rate)
- Long-running detection (stuck >1h, >24h)
- Stuck queued detection
- Throughput metrics (completions/hour)
- RabbitMQ queue depths
"""
logger.info("Getting execution diagnostics")
diagnostics = await get_execution_diagnostics()
response = ExecutionDiagnosticsResponse(
running_executions=diagnostics.running_count,
queued_executions_db=diagnostics.queued_db_count,
queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth,
cancel_queue_depth=diagnostics.cancel_queue_depth,
orphaned_running=diagnostics.orphaned_running,
orphaned_queued=diagnostics.orphaned_queued,
failed_count_1h=diagnostics.failed_count_1h,
failed_count_24h=diagnostics.failed_count_24h,
failure_rate_24h=diagnostics.failure_rate_24h,
stuck_running_24h=diagnostics.stuck_running_24h,
stuck_running_1h=diagnostics.stuck_running_1h,
oldest_running_hours=diagnostics.oldest_running_hours,
stuck_queued_1h=diagnostics.stuck_queued_1h,
queued_never_started=diagnostics.queued_never_started,
invalid_queued_with_start=diagnostics.invalid_queued_with_start,
invalid_running_without_start=diagnostics.invalid_running_without_start,
completed_1h=diagnostics.completed_1h,
completed_24h=diagnostics.completed_24h,
throughput_per_hour=diagnostics.throughput_per_hour,
timestamp=diagnostics.timestamp,
)
logger.info(
f"Execution diagnostics: running={diagnostics.running_count}, "
f"queued_db={diagnostics.queued_db_count}, "
f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, "
f"failed_24h={diagnostics.failed_count_24h}"
)
return response
@router.get(
"/diagnostics/agents",
response_model=AgentDiagnosticsResponse,
summary="Get Agent Diagnostics",
)
async def get_agent_diagnostics_endpoint():
"""
Get diagnostic information about agents.
Returns:
- agents_with_active_executions: Number of unique agents with running/queued executions
- timestamp: Current timestamp
"""
logger.info("Getting agent diagnostics")
diagnostics = await get_agent_diagnostics()
response = AgentDiagnosticsResponse(
agents_with_active_executions=diagnostics.agents_with_active_executions,
timestamp=diagnostics.timestamp,
)
logger.info(
f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}"
)
return response
@router.get(
"/diagnostics/executions/running",
response_model=RunningExecutionsListResponse,
summary="List Running Executions",
)
async def list_running_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of running and queued executions (recent, likely active).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of running executions with details
"""
logger.info(f"Listing running executions (limit={limit}, offset={offset})")
executions = await get_running_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.running_count + diagnostics.queued_db_count
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/orphaned",
response_model=RunningExecutionsListResponse,
summary="List Orphaned Executions",
)
async def list_orphaned_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of orphaned executions (>24h old, likely not in executor).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of orphaned executions with details
"""
logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})")
executions = await get_orphaned_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.orphaned_running + diagnostics.orphaned_queued
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/failed",
response_model=FailedExecutionsListResponse,
summary="List Failed Executions",
)
async def list_failed_executions(
limit: int = 100,
offset: int = 0,
hours: int = 24,
):
"""
Get detailed list of failed executions.
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
hours: Number of hours to look back (default 24)
Returns:
List of failed executions with error details
"""
logger.info(
f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})"
)
executions = await get_failed_executions_details(
limit=limit, offset=offset, hours=hours
)
# Total for pagination: count failures over the same `hours` window as the page
total = await get_failed_executions_count(hours=hours)
return FailedExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/long-running",
response_model=RunningExecutionsListResponse,
summary="List Long-Running Executions",
)
async def list_long_running_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of long-running executions (RUNNING status >24h).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of long-running executions with details
"""
logger.info(f"Listing long-running executions (limit={limit}, offset={offset})")
executions = await get_long_running_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.stuck_running_24h
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/stuck-queued",
response_model=RunningExecutionsListResponse,
summary="List Stuck Queued Executions",
)
async def list_stuck_queued_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of stuck queued executions (QUEUED >1h, never started).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of stuck queued executions with details
"""
logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})")
executions = await get_stuck_queued_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.stuck_queued_1h
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/invalid",
response_model=RunningExecutionsListResponse,
summary="List Invalid Executions",
)
async def list_invalid_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of executions in invalid states (READ-ONLY).
Invalid states indicate data corruption and require manual investigation:
- QUEUED but has startedAt (impossible - can't start while queued)
- RUNNING but no startedAt (impossible - can't run without starting)
⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation.
Each invalid execution likely has a different root cause (crashes, race conditions,
DB corruption). Investigate the execution history and logs to determine appropriate
action (manual cleanup, status fix, or leave as-is if system recovered).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of invalid state executions with details
"""
logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})")
executions = await get_invalid_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = (
diagnostics.invalid_queued_with_start
+ diagnostics.invalid_running_without_start
)
return RunningExecutionsListResponse(executions=executions, total=total)
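# Illustration only, not part of the route module: a minimal sketch of the two
# invalid-state rules listed in the docstring above. The helper name
# `invalid_state_reason` and the plain-string statuses are assumptions made for
# this example; the actual query behind get_invalid_executions_details is not shown here.

```python
from datetime import datetime, timezone
from typing import Optional


def invalid_state_reason(status: str, started_at: Optional[datetime]) -> Optional[str]:
    """Return why a (status, started_at) pair is invalid, or None if it is consistent."""
    if status == "QUEUED" and started_at is not None:
        # A queued execution cannot already have a start time.
        return "QUEUED but has startedAt"
    if status == "RUNNING" and started_at is None:
        # A running execution must have recorded when it started.
        return "RUNNING but no startedAt"
    return None


now = datetime.now(timezone.utc)
print(invalid_state_reason("QUEUED", now))
print(invalid_state_reason("RUNNING", None))
print(invalid_state_reason("RUNNING", now))
```

# The check only classifies a record. It takes no action, which matches the
# read-only intent of the endpoint.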
@router.post(
"/diagnostics/executions/requeue",
response_model=RequeueExecutionResponse,
summary="Requeue Stuck Execution",
)
async def requeue_single_execution(
request: StopExecutionRequest, # Reuse same request model (has execution_id)
user: AuthUser = Security(requires_admin_user),
):
"""
Requeue a stuck QUEUED execution (admin only).
Uses add_graph_execution with existing graph_exec_id to requeue.
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
Args:
request: Contains execution_id to requeue
Returns:
Success status and message
"""
logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}")
# Get the execution (validation - must be QUEUED)
executions = await get_graph_executions(
graph_exec_id=request.execution_id,
statuses=[AgentExecutionStatus.QUEUED],
)
if not executions:
raise HTTPException(
status_code=404,
detail="Execution not found or not in QUEUED status",
)
execution = executions[0]
# Use add_graph_execution in requeue mode
await add_graph_execution(
graph_id=execution.graph_id,
user_id=execution.user_id,
graph_version=execution.graph_version,
graph_exec_id=request.execution_id, # Requeue existing execution
)
return RequeueExecutionResponse(
success=True,
requeued_count=1,
message="Execution requeued successfully",
)
@router.post(
"/diagnostics/executions/requeue-bulk",
response_model=RequeueExecutionResponse,
summary="Requeue Multiple Stuck Executions",
)
async def requeue_multiple_executions(
request: StopExecutionsRequest, # Reuse same request model (has execution_ids)
user: AuthUser = Security(requires_admin_user),
):
"""
Requeue multiple stuck QUEUED executions (admin only).
Uses add_graph_execution with existing graph_exec_id to requeue.
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
Args:
request: Contains list of execution_ids to requeue
Returns:
Number of executions requeued and success message
"""
logger.info(
f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions"
)
# Get executions by ID list (must be QUEUED)
executions = await get_graph_executions(
execution_ids=request.execution_ids,
statuses=[AgentExecutionStatus.QUEUED],
)
if not executions:
return RequeueExecutionResponse(
success=False,
requeued_count=0,
message="No QUEUED executions found to requeue",
)
# Requeue all executions in parallel using add_graph_execution
async def requeue_one(execution) -> bool:
# `execution` rather than `exec`, to avoid shadowing the builtin
try:
await add_graph_execution(
graph_id=execution.graph_id,
user_id=execution.user_id,
graph_version=execution.graph_version,
graph_exec_id=execution.id, # Requeue existing
)
return True
except Exception as e:
logger.error(f"Failed to requeue {execution.id}: {e}")
return False
results = await asyncio.gather(
*[requeue_one(execution) for execution in executions], return_exceptions=False
)
requeued_count = sum(1 for success in results if success)
return RequeueExecutionResponse(
success=requeued_count > 0,
requeued_count=requeued_count,
message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions",
)
@router.post(
"/diagnostics/executions/stop",
response_model=StopExecutionResponse,
summary="Stop Single Execution",
)
async def stop_single_execution(
request: StopExecutionRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Stop a single execution (admin only).
Uses robust stop_graph_execution which cascades to children and waits for termination.
Args:
request: Contains execution_id to stop
Returns:
Success status and message
"""
logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}")
# Get the execution to find its owner user_id (required by stop_graph_execution)
executions = await get_graph_executions(
graph_exec_id=request.execution_id,
)
if not executions:
raise HTTPException(status_code=404, detail="Execution not found")
execution = executions[0]
# Use robust stop_graph_execution (cascades to children, waits for termination)
await stop_graph_execution(
user_id=execution.user_id,
graph_exec_id=request.execution_id,
wait_timeout=15.0,
cascade=True,
)
return StopExecutionResponse(
success=True,
stopped_count=1,
message="Execution stopped successfully",
)
@router.post(
"/diagnostics/executions/stop-bulk",
response_model=StopExecutionResponse,
summary="Stop Multiple Executions",
)
async def stop_multiple_executions(
request: StopExecutionsRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Stop multiple active executions (admin only).
Uses robust stop_graph_execution which cascades to children and waits for termination.
Args:
request: Contains list of execution_ids to stop
Returns:
Number of executions stopped and success message
"""
logger.info(
f"Admin {user.user_id} stopping {len(request.execution_ids)} executions"
)
# Get executions by ID list
executions = await get_graph_executions(
execution_ids=request.execution_ids,
)
if not executions:
return StopExecutionResponse(
success=False,
stopped_count=0,
message="No executions found",
)
# Stop all executions in parallel using robust stop_graph_execution
async def stop_one(execution) -> bool:
# `execution` rather than `exec`, to avoid shadowing the builtin
try:
await stop_graph_execution(
user_id=execution.user_id,
graph_exec_id=execution.id,
wait_timeout=15.0,
cascade=True,
)
return True
except Exception as e:
logger.error(f"Failed to stop execution {execution.id}: {e}")
return False
results = await asyncio.gather(
*[stop_one(execution) for execution in executions], return_exceptions=False
)
stopped_count = sum(1 for success in results if success)
return StopExecutionResponse(
success=stopped_count > 0,
stopped_count=stopped_count,
message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions",
)
@router.post(
"/diagnostics/executions/cleanup-orphaned",
response_model=StopExecutionResponse,
summary="Cleanup Orphaned Executions",
)
async def cleanup_orphaned_executions(
request: StopExecutionsRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Clean up orphaned executions by directly updating DB status (admin only).
For executions that exist in the DB but are not actually running in the executor (old/stale records).
Args:
request: Contains list of execution_ids to clean up
Returns:
Number of executions cleaned up and success message
"""
logger.info(
f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions"
)
cleaned_count = await cleanup_orphaned_executions_bulk(
request.execution_ids, user.user_id
)
return StopExecutionResponse(
success=cleaned_count > 0,
stopped_count=cleaned_count,
message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions",
)
# ============================================================================
# SCHEDULE DIAGNOSTICS ENDPOINTS
# ============================================================================
class SchedulesListResponse(BaseModel):
"""Response model for list of schedules"""
schedules: List[ScheduleDetail]
total: int
class OrphanedSchedulesListResponse(BaseModel):
"""Response model for list of orphaned schedules"""
schedules: List[OrphanedScheduleDetail]
total: int
class ScheduleCleanupRequest(BaseModel):
"""Request model for cleaning up schedules"""
schedule_ids: List[str]
class ScheduleCleanupResponse(BaseModel):
"""Response model for schedule cleanup operations"""
success: bool
deleted_count: int = 0
message: str
@router.get(
"/diagnostics/schedules",
response_model=ScheduleHealthMetrics,
summary="Get Schedule Diagnostics",
)
async def get_schedule_diagnostics_endpoint():
"""
Get comprehensive diagnostic information about schedule health.
Returns schedule metrics including:
- Total schedules (user vs system)
- Orphaned schedules by category
- Upcoming executions
"""
logger.info("Getting schedule diagnostics")
diagnostics = await get_schedule_health_metrics()
logger.info(
f"Schedule diagnostics: total={diagnostics.total_schedules}, "
f"user={diagnostics.user_schedules}, "
f"orphaned={diagnostics.total_orphaned}"
)
return diagnostics
@router.get(
"/diagnostics/schedules/all",
response_model=SchedulesListResponse,
summary="List All User Schedules",
)
async def list_all_schedules(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of all user schedules (excludes system monitoring jobs).
Args:
limit: Maximum number of schedules to return (default 100)
offset: Number of schedules to skip (default 0)
Returns:
List of schedules with details
"""
logger.info(f"Listing all schedules (limit={limit}, offset={offset})")
schedules = await get_all_schedules_details(limit=limit, offset=offset)
# Get total count
diagnostics = await get_schedule_health_metrics()
total = diagnostics.user_schedules
return SchedulesListResponse(schedules=schedules, total=total)
@router.get(
"/diagnostics/schedules/orphaned",
response_model=OrphanedSchedulesListResponse,
summary="List Orphaned Schedules",
)
async def list_orphaned_schedules():
"""
Get detailed list of orphaned schedules with orphan reasons.
Returns:
List of orphaned schedules categorized by orphan type
"""
logger.info("Listing orphaned schedules")
schedules = await get_orphaned_schedules_details()
return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules))
@router.post(
"/diagnostics/schedules/cleanup-orphaned",
response_model=ScheduleCleanupResponse,
summary="Cleanup Orphaned Schedules",
)
async def cleanup_orphaned_schedules(
request: ScheduleCleanupRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Clean up orphaned schedules by deleting them from the scheduler (admin only).
Args:
request: Contains list of schedule_ids to delete
Returns:
Number of schedules deleted and success message
"""
logger.info(
f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules"
)
deleted_count = await cleanup_orphaned_schedules_bulk(
request.schedule_ids, user.user_id
)
return ScheduleCleanupResponse(
success=deleted_count > 0,
deleted_count=deleted_count,
message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules",
)
@router.post(
"/diagnostics/executions/stop-all-long-running",
response_model=StopExecutionResponse,
summary="Stop ALL Long-Running Executions",
)
async def stop_all_long_running_executions_endpoint(
user: AuthUser = Security(requires_admin_user),
):
"""
Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only).
Operates on entire dataset, not limited to pagination.
Returns:
Number of executions stopped and success message
"""
logger.info(f"Admin {user.user_id} stopping ALL long-running executions")
stopped_count = await stop_all_long_running_executions(user.user_id)
return StopExecutionResponse(
success=stopped_count > 0,
stopped_count=stopped_count,
message=f"Stopped {stopped_count} long-running executions",
)
@router.post(
"/diagnostics/executions/cleanup-all-orphaned",
response_model=StopExecutionResponse,
summary="Cleanup ALL Orphaned Executions",
)
async def cleanup_all_orphaned_executions(
user: AuthUser = Security(requires_admin_user),
):
"""
Clean up ALL orphaned executions (>24h old) by directly updating DB status (admin only).
Operates on all executions, not just paginated results.
Returns:
Number of executions cleaned up and success message
"""
logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions")
# Fetch all orphaned execution IDs
execution_ids = await get_all_orphaned_execution_ids()
if not execution_ids:
return StopExecutionResponse(
success=True,
stopped_count=0,
message="No orphaned executions to clean up",
)
cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id)
return StopExecutionResponse(
success=cleaned_count > 0,
stopped_count=cleaned_count,
message=f"Cleaned up {cleaned_count} orphaned executions",
)
@router.post(
"/diagnostics/executions/cleanup-all-stuck-queued",
response_model=StopExecutionResponse,
summary="Cleanup ALL Stuck Queued Executions",
)
async def cleanup_all_stuck_queued_executions_endpoint(
user: AuthUser = Security(requires_admin_user),
):
"""
Clean up ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only).
Operates on entire dataset, not limited to pagination.
Returns:
Number of executions cleaned up and success message
"""
logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions")
cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id)
return StopExecutionResponse(
success=cleaned_count > 0,
stopped_count=cleaned_count,
message=f"Cleaned up {cleaned_count} stuck queued executions",
)
@router.post(
"/diagnostics/executions/requeue-all-stuck",
response_model=RequeueExecutionResponse,
summary="Requeue ALL Stuck Queued Executions",
)
async def requeue_all_stuck_executions(
user: AuthUser = Security(requires_admin_user),
):
"""
Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ.
Operates on all executions, not just paginated results.
Uses add_graph_execution with existing graph_exec_id to requeue.
⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits.
Returns:
Number of executions requeued and success message
"""
logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions")
# Fetch all stuck queued execution IDs
execution_ids = await get_all_stuck_queued_execution_ids()
if not execution_ids:
return RequeueExecutionResponse(
success=True,
requeued_count=0,
message="No stuck queued executions to requeue",
)
# Get stuck executions by ID list (must be QUEUED)
executions = await get_graph_executions(
execution_ids=execution_ids,
statuses=[AgentExecutionStatus.QUEUED],
)
# Requeue all in parallel using add_graph_execution
async def requeue_one(execution) -> bool:
# `execution` rather than `exec`, to avoid shadowing the builtin
try:
await add_graph_execution(
graph_id=execution.graph_id,
user_id=execution.user_id,
graph_version=execution.graph_version,
graph_exec_id=execution.id, # Requeue existing
)
return True
except Exception as e:
logger.error(f"Failed to requeue {execution.id}: {e}")
return False
results = await asyncio.gather(
*[requeue_one(execution) for execution in executions], return_exceptions=False
)
requeued_count = sum(1 for success in results if success)
return RequeueExecutionResponse(
success=requeued_count > 0,
requeued_count=requeued_count,
message=f"Requeued {requeued_count} stuck executions",
)

@@ -1,889 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma.enums import AgentExecutionStatus
import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes
from backend.data.diagnostics import (
AgentDiagnosticsSummary,
ExecutionDiagnosticsSummary,
FailedExecutionDetail,
OrphanedScheduleDetail,
RunningExecutionDetail,
ScheduleDetail,
ScheduleHealthMetrics,
)
from backend.data.execution import GraphExecutionMeta
app = fastapi.FastAPI()
app.include_router(diagnostics_admin_routes.router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_get_execution_diagnostics_success(
mocker: pytest_mock.MockFixture,
):
"""Test fetching execution diagnostics with invalid state detection"""
mock_diagnostics = ExecutionDiagnosticsSummary(
running_count=10,
queued_db_count=5,
rabbitmq_queue_depth=3,
cancel_queue_depth=0,
orphaned_running=2,
orphaned_queued=1,
failed_count_1h=5,
failed_count_24h=20,
failure_rate_24h=0.83,
stuck_running_24h=1,
stuck_running_1h=3,
oldest_running_hours=26.5,
stuck_queued_1h=2,
queued_never_started=1,
invalid_queued_with_start=1, # New invalid state
invalid_running_without_start=1, # New invalid state
completed_1h=50,
completed_24h=1200,
throughput_per_hour=50.0,
timestamp=datetime.now(timezone.utc).isoformat(),
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=mock_diagnostics,
)
response = client.get("/admin/diagnostics/executions")
assert response.status_code == 200
data = response.json()
# Verify new invalid state fields are included
assert data["invalid_queued_with_start"] == 1
assert data["invalid_running_without_start"] == 1
# Verify all expected fields present
assert "running_executions" in data
assert "orphaned_running" in data
assert "failed_count_24h" in data
def test_list_invalid_executions(
mocker: pytest_mock.MockFixture,
):
"""Test listing executions in invalid states (read-only endpoint)"""
mock_invalid_executions = [
RunningExecutionDetail(
execution_id="exec-invalid-1",
graph_id="graph-123",
graph_name="Test Graph",
graph_version=1,
user_id="user-123",
user_email="test@example.com",
status="QUEUED",
created_at=datetime.now(timezone.utc),
started_at=datetime.now(
timezone.utc
), # QUEUED but has startedAt - INVALID!
queue_status=None,
),
RunningExecutionDetail(
execution_id="exec-invalid-2",
graph_id="graph-456",
graph_name="Another Graph",
graph_version=2,
user_id="user-456",
user_email="user@example.com",
status="RUNNING",
created_at=datetime.now(timezone.utc),
started_at=None, # RUNNING but no startedAt - INVALID!
queue_status=None,
),
]
mock_diagnostics = ExecutionDiagnosticsSummary(
running_count=10,
queued_db_count=5,
rabbitmq_queue_depth=3,
cancel_queue_depth=0,
orphaned_running=0,
orphaned_queued=0,
failed_count_1h=0,
failed_count_24h=0,
failure_rate_24h=0.0,
stuck_running_24h=0,
stuck_running_1h=0,
oldest_running_hours=None,
stuck_queued_1h=0,
queued_never_started=0,
invalid_queued_with_start=1,
invalid_running_without_start=1,
completed_1h=0,
completed_24h=0,
throughput_per_hour=0.0,
timestamp=datetime.now(timezone.utc).isoformat(),
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details",
return_value=mock_invalid_executions,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=mock_diagnostics,
)
response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2 # Sum of both invalid state types
assert len(data["executions"]) == 2
# Verify both types of invalid states are returned (each ID exactly once)
assert {e["execution_id"] for e in data["executions"]} == {
"exec-invalid-1",
"exec-invalid-2",
}
def test_requeue_single_execution_with_add_graph_execution(
mocker: pytest_mock.MockFixture,
admin_user_id: str,
):
"""Test requeueing uses add_graph_execution in requeue mode"""
mock_exec_meta = GraphExecutionMeta(
id="exec-stuck-123",
user_id="user-123",
graph_id="graph-456",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.QUEUED,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[mock_exec_meta],
)
mock_add_graph_execution = mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/requeue",
json={"execution_id": "exec-stuck-123"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["requeued_count"] == 1
# Verify it used add_graph_execution in requeue mode
mock_add_graph_execution.assert_called_once()
call_kwargs = mock_add_graph_execution.call_args.kwargs
assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode!
assert call_kwargs["graph_id"] == "graph-456"
assert call_kwargs["user_id"] == "user-123"
def test_stop_single_execution_with_stop_graph_execution(
mocker: pytest_mock.MockFixture,
admin_user_id: str,
):
"""Test stopping uses robust stop_graph_execution"""
mock_exec_meta = GraphExecutionMeta(
id="exec-running-123",
user_id="user-789",
graph_id="graph-999",
graph_version=2,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.RUNNING,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[mock_exec_meta],
)
mock_stop_graph_execution = mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/stop",
json={"execution_id": "exec-running-123"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 1
# Verify it used stop_graph_execution with cascade
mock_stop_graph_execution.assert_called_once()
call_kwargs = mock_stop_graph_execution.call_args.kwargs
assert call_kwargs["graph_exec_id"] == "exec-running-123"
assert call_kwargs["user_id"] == "user-789"
assert call_kwargs["cascade"] is True # Stops children too!
assert call_kwargs["wait_timeout"] == 15.0
def test_requeue_not_queued_execution_fails(
mocker: pytest_mock.MockFixture,
):
"""Test that requeue fails if execution is not in QUEUED status"""
# Simulate a RUNNING execution: the QUEUED-status filter matches nothing
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[], # No QUEUED executions found
)
response = client.post(
"/admin/diagnostics/executions/requeue",
json={"execution_id": "exec-running-123"},
)
assert response.status_code == 404
assert "not found or not in QUEUED status" in response.json()["detail"]
def test_list_invalid_executions_no_bulk_actions(
mocker: pytest_mock.MockFixture,
):
"""Verify invalid executions endpoint is read-only (no bulk actions)"""
# This is a documentation test - the endpoint exists but should not
# have corresponding cleanup/stop/requeue endpoints
# These endpoints should NOT exist for invalid states:
invalid_bulk_endpoints = [
"/admin/diagnostics/executions/cleanup-invalid",
"/admin/diagnostics/executions/stop-invalid",
"/admin/diagnostics/executions/requeue-invalid",
]
for endpoint in invalid_bulk_endpoints:
response = client.post(endpoint, json={"execution_ids": ["test"]})
assert response.status_code == 404, f"{endpoint} should not exist (read-only)"
def test_execution_ids_filter_efficiency(
mocker: pytest_mock.MockFixture,
):
"""Test that bulk operations use efficient execution_ids filter"""
mock_exec_metas = [
GraphExecutionMeta(
id=f"exec-{i}",
user_id=f"user-{i}",
graph_id="graph-123",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.QUEUED,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
for i in range(3)
]
mock_get_graph_executions = mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=mock_exec_metas,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/requeue-bulk",
json={"execution_ids": ["exec-0", "exec-1", "exec-2"]},
)
assert response.status_code == 200
# Verify it used execution_ids filter (not fetching all queued)
mock_get_graph_executions.assert_called_once()
call_kwargs = mock_get_graph_executions.call_args.kwargs
assert "execution_ids" in call_kwargs
assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"]
assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED]
# ---------------------------------------------------------------------------
# Helper: reusable mock diagnostics summary
# ---------------------------------------------------------------------------
def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary:
defaults = dict(
running_count=10,
queued_db_count=5,
rabbitmq_queue_depth=3,
cancel_queue_depth=0,
orphaned_running=2,
orphaned_queued=1,
failed_count_1h=5,
failed_count_24h=20,
failure_rate_24h=0.83,
stuck_running_24h=3,
stuck_running_1h=5,
oldest_running_hours=26.5,
stuck_queued_1h=2,
queued_never_started=1,
invalid_queued_with_start=1,
invalid_running_without_start=1,
completed_1h=50,
completed_24h=1200,
throughput_per_hour=50.0,
timestamp=datetime.now(timezone.utc).isoformat(),
)
defaults.update(overrides)
return ExecutionDiagnosticsSummary(**defaults)
_SENTINEL = object()
def _make_mock_execution(
exec_id: str = "exec-1",
status: str = "RUNNING",
started_at: datetime | None | object = _SENTINEL,
) -> RunningExecutionDetail:
return RunningExecutionDetail(
execution_id=exec_id,
graph_id="graph-123",
graph_name="Test Graph",
graph_version=1,
user_id="user-123",
user_email="test@example.com",
status=status,
created_at=datetime.now(timezone.utc),
started_at=(
datetime.now(timezone.utc) if started_at is _SENTINEL else started_at
),
queue_status=None,
)
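# Illustration only, not part of the test module: `_make_mock_execution` uses a
# sentinel default so that callers can pass started_at=None on purpose (the
# "RUNNING but no startedAt" case) and still get "now" when the argument is
# omitted. A minimal standalone sketch of that idiom, with the hypothetical
# names `_UNSET` and `resolve_started_at`:

```python
from datetime import datetime, timezone
from typing import Optional, Union

_UNSET = object()  # unique marker meaning "argument was not passed"


def resolve_started_at(
    started_at: Union[datetime, None, object] = _UNSET,
) -> Optional[datetime]:
    """Default to the current time only when no value was supplied at all."""
    if started_at is _UNSET:
        return datetime.now(timezone.utc)
    # An explicit None (or an explicit datetime) is passed through unchanged.
    return started_at  # type: ignore[return-value]


print(resolve_started_at(None))
print(resolve_started_at() is not None)
```

# A plain `started_at=None` default could not tell "omitted" from "explicitly
# None", which is the distinction the invalid-state tests rely on.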
def _make_mock_failed_execution(
exec_id: str = "exec-fail-1",
) -> FailedExecutionDetail:
return FailedExecutionDetail(
execution_id=exec_id,
graph_id="graph-123",
graph_name="Test Graph",
graph_version=1,
user_id="user-123",
user_email="test@example.com",
status="FAILED",
created_at=datetime.now(timezone.utc),
started_at=datetime.now(timezone.utc),
failed_at=datetime.now(timezone.utc),
error_message="Something went wrong",
)
def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics:
defaults = dict(
total_schedules=15,
user_schedules=10,
system_schedules=5,
orphaned_deleted_graph=2,
orphaned_no_library_access=1,
orphaned_invalid_credentials=0,
orphaned_validation_failed=0,
total_orphaned=3,
schedules_next_hour=4,
schedules_next_24h=8,
total_runs_next_hour=12,
total_runs_next_24h=48,
timestamp=datetime.now(timezone.utc).isoformat(),
)
defaults.update(overrides)
return ScheduleHealthMetrics(**defaults)
# ---------------------------------------------------------------------------
# GET endpoints: execution list variants
# ---------------------------------------------------------------------------
def test_list_running_executions(mocker: pytest_mock.MockFixture):
mock_execs = [
_make_mock_execution("exec-run-1"),
_make_mock_execution("exec-run-2"),
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 15 # running_count(10) + queued_db_count(5)
assert len(data["executions"]) == 2
assert data["executions"][0]["execution_id"] == "exec-run-1"
def test_list_orphaned_executions(mocker: pytest_mock.MockFixture):
mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3 # orphaned_running(2) + orphaned_queued(1)
assert len(data["executions"]) == 1
def test_list_failed_executions(mocker: pytest_mock.MockFixture):
mock_execs = [_make_mock_failed_execution("exec-fail-1")]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count",
return_value=42,
)
response = client.get(
"/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24"
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 42
assert len(data["executions"]) == 1
assert data["executions"][0]["error_message"] == "Something went wrong"
def test_list_long_running_executions(mocker: pytest_mock.MockFixture):
mock_execs = [_make_mock_execution("exec-long-1")]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get(
"/admin/diagnostics/executions/long-running?limit=50&offset=0"
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 3 # stuck_running_24h
assert len(data["executions"]) == 1
def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture):
mock_execs = [
_make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None)
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get(
"/admin/diagnostics/executions/stuck-queued?limit=50&offset=0"
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 2 # stuck_queued_1h
assert len(data["executions"]) == 1
# ---------------------------------------------------------------------------
# GET endpoints: agent + schedule diagnostics
# ---------------------------------------------------------------------------
def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture):
mock_diag = AgentDiagnosticsSummary(
agents_with_active_executions=7,
timestamp=datetime.now(timezone.utc).isoformat(),
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics",
return_value=mock_diag,
)
response = client.get("/admin/diagnostics/agents")
assert response.status_code == 200
data = response.json()
assert data["agents_with_active_executions"] == 7
def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture):
mock_metrics = _make_mock_schedule_health()
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
return_value=mock_metrics,
)
response = client.get("/admin/diagnostics/schedules")
assert response.status_code == 200
data = response.json()
assert data["user_schedules"] == 10
assert data["total_orphaned"] == 3
assert data["total_runs_next_hour"] == 12
def test_list_all_schedules(mocker: pytest_mock.MockFixture):
mock_schedules = [
ScheduleDetail(
schedule_id="sched-1",
schedule_name="Daily Run",
graph_id="graph-1",
graph_name="My Agent",
graph_version=1,
user_id="user-1",
user_email="alice@example.com",
cron="0 9 * * *",
timezone="UTC",
next_run_time=datetime.now(timezone.utc).isoformat(),
),
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details",
return_value=mock_schedules,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
return_value=_make_mock_schedule_health(),
)
response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 10
assert len(data["schedules"]) == 1
assert data["schedules"][0]["schedule_name"] == "Daily Run"
def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture):
mock_orphans = [
OrphanedScheduleDetail(
schedule_id="sched-orphan-1",
schedule_name="Ghost Schedule",
graph_id="graph-deleted",
graph_version=1,
user_id="user-1",
orphan_reason="deleted_graph",
error_detail=None,
next_run_time=datetime.now(timezone.utc).isoformat(),
),
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details",
return_value=mock_orphans,
)
response = client.get("/admin/diagnostics/schedules/orphaned")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["schedules"][0]["orphan_reason"] == "deleted_graph"
# ---------------------------------------------------------------------------
# POST endpoints: bulk stop, cleanup, requeue
# ---------------------------------------------------------------------------
def test_stop_multiple_executions(mocker: pytest_mock.MockFixture):
mock_exec_metas = [
GraphExecutionMeta(
id=f"exec-{i}",
user_id=f"user-{i}",
graph_id="graph-123",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.RUNNING,
started_at=datetime.now(timezone.utc),
ended_at=None,
stats=None,
)
for i in range(2)
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=mock_exec_metas,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/stop-bulk",
json={"execution_ids": ["exec-0", "exec-1"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 2
def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[],
)
response = client.post(
"/admin/diagnostics/executions/stop-bulk",
json={"execution_ids": ["nonexistent"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["stopped_count"] == 0
def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
return_value=3,
)
response = client.post(
"/admin/diagnostics/executions/cleanup-orphaned",
json={"execution_ids": ["exec-1", "exec-2", "exec-3"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 3
def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk",
return_value=2,
)
response = client.post(
"/admin/diagnostics/schedules/cleanup-orphaned",
json={"schedule_ids": ["sched-1", "sched-2"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions",
return_value=5,
)
response = client.post("/admin/diagnostics/executions/stop-all-long-running")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 5
def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
return_value=["exec-1", "exec-2"],
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
return_value=2,
)
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 2
def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
return_value=[],
)
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 0
assert "No orphaned" in data["message"]
def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions",
return_value=4,
)
response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 4
def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture):
mock_exec_metas = [
GraphExecutionMeta(
id=f"exec-stuck-{i}",
user_id=f"user-{i}",
graph_id="graph-123",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.QUEUED,
started_at=None,
ended_at=None,
stats=None,
)
for i in range(3)
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"],
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=mock_exec_metas,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
return_value=AsyncMock(),
)
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["requeued_count"] == 3
def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
return_value=[],
)
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["requeued_count"] == 0
assert "No stuck" in data["message"]
def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[],
)
response = client.post(
"/admin/diagnostics/executions/requeue-bulk",
json={"execution_ids": ["nonexistent"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["requeued_count"] == 0
def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[],
)
response = client.post(
"/admin/diagnostics/executions/stop",
json={"execution_id": "nonexistent"},
)
assert response.status_code == 404
assert "not found" in response.json()["detail"]
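
Note on the `_SENTINEL` default used by `_make_mock_execution` above: a unique `object()` lets the helper tell "argument omitted" apart from an explicit `None`, which the stuck-queued test relies on when it passes `started_at=None`. A minimal standalone sketch of that pattern follows; the names `_MISSING` and `resolve_started_at` are hypothetical and do not appear in the diff.

```python
from datetime import datetime, timezone

# Hypothetical marker meaning "the caller did not pass this argument".
_MISSING = object()


def resolve_started_at(started_at=_MISSING):
    """Return the current UTC time when omitted, else the given value (None included)."""
    if started_at is _MISSING:
        return datetime.now(timezone.utc)
    return started_at
```

With a plain `None` default the helper could not build a QUEUED execution that has no start time, because `None` would be replaced by "now".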


@@ -14,70 +14,3 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class ExecutionDiagnosticsResponse(BaseModel):
"""Response model for execution diagnostics"""
# Current execution state
running_executions: int
queued_executions_db: int
queued_executions_rabbitmq: int
cancel_queue_depth: int
# Orphaned execution detection
orphaned_running: int
orphaned_queued: int
# Failure metrics
failed_count_1h: int
failed_count_24h: int
failure_rate_24h: float
# Long-running detection
stuck_running_24h: int
stuck_running_1h: int
oldest_running_hours: float | None
# Stuck queued detection
stuck_queued_1h: int
queued_never_started: int
# Invalid state detection (data corruption - no auto-actions)
invalid_queued_with_start: int
invalid_running_without_start: int
# Throughput metrics
completed_1h: int
completed_24h: int
throughput_per_hour: float
timestamp: str
class AgentDiagnosticsResponse(BaseModel):
"""Response model for agent diagnostics"""
agents_with_active_executions: int
timestamp: str
class ScheduleHealthMetrics(BaseModel):
"""Response model for schedule diagnostics"""
total_schedules: int
user_schedules: int
system_schedules: int
# Orphan detection
orphaned_deleted_graph: int
orphaned_no_library_access: int
orphaned_invalid_credentials: int
orphaned_validation_failed: int
total_orphaned: int
# Upcoming
schedules_next_hour: int
schedules_next_24h: int
timestamp: str


@@ -43,7 +43,6 @@ async def get_cost_dashboard(
model: str | None = Query(None),
block_name: str | None = Query(None),
tracking_type: str | None = Query(None),
graph_exec_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
return await get_platform_cost_dashboard(
@@ -54,7 +53,6 @@ async def get_cost_dashboard(
model=model,
block_name=block_name,
tracking_type=tracking_type,
graph_exec_id=graph_exec_id,
)
@@ -74,7 +72,6 @@ async def get_cost_logs(
model: str | None = Query(None),
block_name: str | None = Query(None),
tracking_type: str | None = Query(None),
graph_exec_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost logs", admin_user_id)
logs, total = await get_platform_cost_logs(
@@ -87,7 +84,6 @@ async def get_cost_logs(
model=model,
block_name=block_name,
tracking_type=tracking_type,
graph_exec_id=graph_exec_id,
)
total_pages = (total + page_size - 1) // page_size
return PlatformCostLogsResponse(
@@ -121,7 +117,6 @@ async def export_cost_logs(
model: str | None = Query(None),
block_name: str | None = Query(None),
tracking_type: str | None = Query(None),
graph_exec_id: str | None = Query(None),
):
logger.info("Admin %s exporting platform cost logs", admin_user_id)
logs, truncated = await get_platform_cost_logs_for_export(
@@ -132,7 +127,6 @@ async def export_cost_logs(
model=model,
block_name=block_name,
tracking_type=tracking_type,
graph_exec_id=graph_exec_id,
)
return PlatformCostExportResponse(
logs=logs,
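
Note on `total_pages = (total + page_size - 1) // page_size` in `get_cost_logs` above: this is integer ceiling division, so a partially filled last page still counts as a page. A small sketch, assuming a positive `page_size`; the function name `page_count` is hypothetical and the validation is an addition not shown in the diff.

```python
def page_count(total: int, page_size: int) -> int:
    """Number of pages needed to show `total` items, `page_size` per page."""
    if page_size <= 0:
        # Assumption: the route validates page_size elsewhere; guard here for safety.
        raise ValueError("page_size must be positive")
    # Adding page_size - 1 before floor division rounds any remainder up.
    return (total + page_size - 1) // page_size
```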


@@ -32,10 +32,10 @@ router = APIRouter(
class UserRateLimitResponse(BaseModel):
user_id: str
user_email: Optional[str] = None
daily_cost_limit_microdollars: int
weekly_cost_limit_microdollars: int
daily_cost_used_microdollars: int
weekly_cost_used_microdollars: int
daily_token_limit: int
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
tier: SubscriptionTier
@@ -101,19 +101,17 @@ async def get_user_rate_limit(
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
resolved_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
return UserRateLimitResponse(
user_id=resolved_id,
user_email=resolved_email,
daily_cost_limit_microdollars=daily_limit,
weekly_cost_limit_microdollars=weekly_limit,
daily_cost_used_microdollars=usage.daily.used,
weekly_cost_used_microdollars=usage.weekly.used,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@@ -143,9 +141,7 @@ async def reset_user_rate_limit(
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
@@ -158,10 +154,10 @@ async def reset_user_rate_limit(
return UserRateLimitResponse(
user_id=user_id,
user_email=resolved_email,
daily_cost_limit_microdollars=daily_limit,
weekly_cost_limit_microdollars=weekly_limit,
daily_cost_used_microdollars=usage.daily.used,
weekly_cost_used_microdollars=usage.weekly.used,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)


@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -85,11 +85,11 @@ def test_get_rate_limit(
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_cost_limit_microdollars"] == 2_500_000
assert data["weekly_cost_limit_microdollars"] == 12_500_000
assert data["daily_cost_used_microdollars"] == 500_000
assert data["weekly_cost_used_microdollars"] == 3_000_000
assert data["tier"] == "BASIC"
assert data["daily_token_limit"] == 2_500_000
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
@@ -117,7 +117,7 @@ def test_get_rate_limit_by_email(
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_cost_limit_microdollars"] == 2_500_000
assert data["daily_token_limit"] == 2_500_000
def test_get_rate_limit_by_email_not_found(
@@ -160,10 +160,10 @@ def test_reset_user_usage_daily_only(
assert response.status_code == 200
data = response.json()
assert data["daily_cost_used_microdollars"] == 0
assert data["daily_tokens_used"] == 0
# Weekly is untouched
assert data["weekly_cost_used_microdollars"] == 3_000_000
assert data["tier"] == "BASIC"
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
@@ -192,9 +192,9 @@ def test_reset_user_usage_daily_and_weekly(
assert response.status_code == 200
data = response.json()
assert data["daily_cost_used_microdollars"] == 0
assert data["weekly_cost_used_microdollars"] == 0
assert data["tier"] == "BASIC"
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
@@ -231,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -324,7 +324,7 @@ def test_set_user_tier(
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.BASIC,
return_value=SubscriptionTier.FREE,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
@@ -347,7 +347,7 @@ def test_set_user_tier_downgrade(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test downgrading a user's tier from PRO to BASIC."""
"""Test downgrading a user's tier from PRO to FREE."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
@@ -365,14 +365,14 @@ def test_set_user_tier_downgrade(
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "BASIC"},
json={"user_id": target_user_id, "tier": "FREE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "BASIC"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.BASIC)
assert data["tier"] == "FREE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
def test_set_user_tier_invalid_tier(
@@ -456,7 +456,7 @@ def test_set_user_tier_db_failure(
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.BASIC,
return_value=SubscriptionTier.FREE,
)
mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",

File diff suppressed because it is too large


@@ -7,7 +7,6 @@ allowing frontend code generators like Orval to create corresponding TypeScript
from pydantic import BaseModel, Field
from backend.data.model import CredentialsType
from backend.integrations.providers import ProviderName
from backend.sdk.registry import AutoRegistry
@@ -48,57 +47,6 @@ class ProviderNamesResponse(BaseModel):
)
class ProviderMetadata(BaseModel):
"""Display metadata for a provider, shown in the settings integrations UI."""
name: str = Field(description="Provider slug (e.g. ``github``)")
description: str | None = Field(
default=None,
description=(
"One-line human-readable summary of what the provider does. "
"Declared via ``ProviderBuilder.with_description(...)`` in the "
"provider's ``_config.py``. ``None`` if not set."
),
)
supported_auth_types: list[CredentialsType] = Field(
default_factory=list,
description=(
"Credential types this provider accepts. Drives which connection "
"tabs the settings UI renders for the provider. Empty list means "
"no auth types declared."
),
)
def get_supported_auth_types(name: str) -> list[CredentialsType]:
"""Return the provider's supported credential types from :class:`AutoRegistry`.
Populated by :meth:`ProviderBuilder.with_supported_auth_types` (or by
``with_oauth`` / ``with_api_key`` / ``with_user_password`` when the provider
uses the full builder chain). Returns an empty list for providers with no
auth types declared.
"""
provider = AutoRegistry.get_provider(name)
if provider is None:
return []
return sorted(provider.supported_auth_types)
def get_provider_description(name: str) -> str | None:
"""Return the provider's description from :class:`AutoRegistry`.
Descriptions are declared via ``ProviderBuilder.with_description(...)`` in
the provider's ``_config.py`` (SDK path) or in
``blocks/_static_provider_configs.py`` (for providers that don't yet have
their own directory). Returns ``None`` for providers with no registered
description.
"""
provider = AutoRegistry.get_provider(name)
if provider is None:
return None
return provider.description
class ProviderConstants(BaseModel):
"""
Model that exposes all provider names as a constant in the OpenAPI schema.


@@ -14,7 +14,7 @@ from fastapi import (
Security,
status,
)
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, SecretStr, model_validator
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -29,14 +29,15 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import (
APIKeyCredentials,
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserIntegrations,
is_sdk_default,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import (
@@ -47,20 +48,8 @@ from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import (
ensure_managed_credential,
ensure_managed_credentials,
)
from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvider
from backend.integrations.managed_providers.ayrshare import (
settings_available as ayrshare_settings_available,
)
from backend.integrations.oauth import (
CREDENTIALS_BY_PROVIDER,
DEVICE_HANDLERS_BY_NAME,
HANDLERS_BY_NAME,
)
from backend.integrations.oauth.device_base import BaseDeviceAuthHandler
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import (
@@ -71,14 +60,7 @@ from backend.util.exceptions import (
)
from backend.util.settings import Settings
from .models import (
ProviderConstants,
ProviderMetadata,
ProviderNamesResponse,
get_all_provider_names,
get_provider_description,
get_supported_auth_types,
)
from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
@@ -105,23 +87,14 @@ async def login(
scopes: Annotated[
str, Query(title="Comma-separated list of authorization scopes")
] = "",
credential_id: Annotated[
str | None,
Query(title="ID of existing credential to upgrade scopes for"),
] = None,
) -> LoginResponse:
handler = _get_provider_oauth_handler(request, provider)
requested_scopes = scopes.split(",") if scopes else []
if credential_id:
requested_scopes = await _prepare_scope_upgrade(
user_id, provider, credential_id, requested_scopes
)
# Generate and store a secure random state token along with the scopes
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id, provider, requested_scopes, credential_id=credential_id
user_id, provider, requested_scopes
)
login_url = handler.get_login_url(
requested_scopes, state_token, code_challenge=code_challenge
@@ -243,9 +216,7 @@ async def callback(
)
# TODO: Allow specifying `title` to set on `credentials`
credentials = await _merge_or_create_credential(
user_id, provider, credentials, valid_state.credential_id
)
await creds_manager.create(user_id, credentials)
logger.debug(
f"Successfully processed OAuth callback for user {user_id} "
@@ -255,193 +226,13 @@ async def callback(
return to_meta_response(credentials)
# ================================================================== #
# Device Code Grant endpoints (RFC 8628)
# ================================================================== #
class DeviceAuthInitiateResponse(BaseModel):
state_token: str
device_code: str
user_code: str
verification_url: str
verification_url_complete: str | None = None
expires_in: int
interval: int
class DeviceAuthPollRequest(BaseModel):
state_token: str
class DeviceAuthPollResponse(BaseModel):
status: str
credentials: CredentialsMetaResponse | None = None
def _get_device_auth_handler(provider: ProviderName) -> BaseDeviceAuthHandler:
provider_key = provider.value if hasattr(provider, "value") else str(provider)
if provider_key not in DEVICE_HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No device-auth handler for provider '{provider_key}'",
)
handler_class = DEVICE_HANDLERS_BY_NAME[provider_key]
return handler_class()
@router.post(
"/{provider}/device-auth/initiate",
summary="Initiate device code OAuth flow",
)
async def device_auth_initiate(
provider: Annotated[
ProviderName,
Path(title="The provider to initiate device auth for"),
],
user_id: Annotated[str, Security(get_user_id)],
scopes: Annotated[
str, Query(title="Comma-separated list of authorization scopes")
] = "",
) -> DeviceAuthInitiateResponse:
handler = _get_device_auth_handler(provider)
requested_scopes = scopes.split(",") if scopes else []
requested_scopes = handler.handle_default_scopes(requested_scopes)
try:
initiation = await handler.initiate_device_auth(requested_scopes)
except Exception as e:
logger.error(f"Device auth initiation failed for {provider}: {e}")
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to initiate device auth: {str(e)}",
)
# Store state with the provider's expiry (not hardcoded 10 min)
state_token, _ = await creds_manager.store.store_state_token(
user_id=user_id,
provider=provider.value if hasattr(provider, "value") else str(provider),
scopes=requested_scopes,
state_metadata={
"flow_type": "device_code",
"device_code": initiation.device_code,
"interval": initiation.interval,
"user_code": initiation.user_code,
},
)
return DeviceAuthInitiateResponse(
state_token=state_token,
device_code=initiation.device_code,
user_code=initiation.user_code,
verification_url=initiation.verification_url,
verification_url_complete=initiation.verification_url_complete,
expires_in=initiation.expires_in,
interval=initiation.interval,
)
@router.post(
"/{provider}/device-auth/poll",
summary="Poll device code OAuth flow for completion",
)
async def device_auth_poll(
provider: Annotated[
ProviderName,
Path(title="The provider to poll device auth for"),
],
body: DeviceAuthPollRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DeviceAuthPollResponse:
handler = _get_device_auth_handler(provider)
# Non-consuming read — state survives across many polls
valid_state = await creds_manager.store.peek_state_token(
user_id, body.state_token, provider
)
if not valid_state:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state token",
)
device_code = valid_state.state_metadata.get("device_code")
if not device_code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="State token is not for a device code flow",
)
try:
result = await handler.poll_for_tokens(device_code)
except Exception as e:
logger.error(f"Device auth poll failed for {provider}: {e}")
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Device auth poll failed: {str(e)}",
)
if result.status in ("pending", "slow_down"):
return DeviceAuthPollResponse(status=result.status)
# Terminal state — consume the token so it can't be reused
await creds_manager.store.consume_state_token(user_id, body.state_token, provider)
if result.status == "approved" and result.credentials:
credentials = result.credentials
credentials.scopes = handler.handle_default_scopes(credentials.scopes)
if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
credentials.scopes = credentials.scopes[0].split(" ")
credentials = await _merge_or_create_credential(
user_id, provider, credentials, valid_state.credential_id
)
logger.debug(
f"Device auth approved for user {user_id} " f"and provider {provider.value}"
)
return DeviceAuthPollResponse(
status="approved",
credentials=to_meta_response(credentials),
)
# denied / expired
return DeviceAuthPollResponse(status=result.status)
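
Note on scope handling in the handlers above: the query parameter is split on commas (`scopes.split(",") if scopes else []`), and `device_auth_poll` additionally splits a single space-delimited entry returned by the provider. A minimal sketch of both steps, with hypothetical helper names `parse_scopes_param` and `normalize_scopes` that are not part of the diff:

```python
def parse_scopes_param(scopes: str) -> list[str]:
    """Split the comma-separated query value; an empty string means no scopes."""
    return scopes.split(",") if scopes else []


def normalize_scopes(scopes: list[str]) -> list[str]:
    """Expand a single space-delimited entry such as ["repo user"] into separate scopes."""
    if len(scopes) == 1 and " " in scopes[0]:
        return scopes[0].split(" ")
    return scopes
```

The empty-string check matters because `"".split(",")` returns `[""]`, which would otherwise be sent as one blank scope.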
# Bound the first-time sweep so a slow upstream (e.g. Ayrshare) can't hang
# the credential-list endpoint. On timeout we still kick off a fire-and-
# forget sweep so provisioning eventually completes; the user just won't
# see the managed cred until the next refresh.
_MANAGED_PROVISION_TIMEOUT_S = 10.0
async def _ensure_managed_credentials_bounded(user_id: str) -> None:
try:
await asyncio.wait_for(
ensure_managed_credentials(user_id, creds_manager.store),
timeout=_MANAGED_PROVISION_TIMEOUT_S,
)
except asyncio.TimeoutError:
logger.warning(
"Managed credential sweep exceeded %.1fs for user=%s; "
"continuing without it — provisioning will complete in background",
_MANAGED_PROVISION_TIMEOUT_S,
user_id,
)
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
@router.get("/credentials", summary="List Credentials")
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Block on provisioning so managed credentials appear on the first load
# instead of after a refresh, but with a timeout so a slow upstream
# can't hang the endpoint. `_provisioned_users` short-circuits on
# repeat calls.
await _ensure_managed_credentials_bounded(user_id)
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -456,7 +247,7 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
await _ensure_managed_credentials_bounded(user_id)
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -490,115 +281,6 @@ async def get_credential(
return to_meta_response(credential)
class PickerTokenResponse(BaseModel):
"""Short-lived OAuth access token shipped to the browser for rendering a
provider-hosted picker UI (e.g. Google Drive Picker). Deliberately narrow:
only the fields the client needs to initialize the picker widget. Issued
from the user's own stored credential so ownership and scope gating are
enforced by the credential lookup."""
access_token: str = Field(
description="OAuth access token suitable for the picker SDK call."
)
access_token_expires_at: int | None = Field(
default=None,
description="Unix timestamp at which the access token expires, if known.",
)
# Allowlist of (provider, scopes) tuples that may mint picker tokens. Only
# Drive-picker-capable scopes qualify so a caller can't use this endpoint to
# extract a GitHub / other-provider OAuth token for unrelated purposes. If a
# future provider integrates a hosted picker that needs a raw access token,
# add its specific picker-relevant scopes here.
_PICKER_TOKEN_ALLOWED_SCOPES: dict[ProviderName, frozenset[str]] = {
ProviderName.GOOGLE: frozenset(
[
"https://www.googleapis.com/auth/drive.file",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive",
]
),
}
@router.post(
"/{provider}/credentials/{cred_id}/picker-token",
summary="Issue a short-lived access token for a provider-hosted picker",
operation_id="postV1GetPickerToken",
)
async def get_picker_token(
provider: Annotated[
ProviderName, Path(title="The provider that owns the credentials")
],
cred_id: Annotated[
str, Path(title="The ID of the OAuth2 credentials to mint a token from")
],
user_id: Annotated[str, Security(get_user_id)],
) -> PickerTokenResponse:
"""Return the raw access token for an OAuth2 credential so the frontend
can initialize a provider-hosted picker (e.g. Google Drive Picker).
`GET /{provider}/credentials/{cred_id}` deliberately strips secrets (see
`CredentialsMetaResponse` + `TestGetCredentialReturnsMetaOnly` in
`router_test.py`). That hardening broke the Drive picker, which needs the
raw access token to call `google.picker.Builder.setOAuthToken(...)`. This
endpoint carves a narrow, explicit hole: the caller must own the
credential, it must be OAuth2, and the endpoint returns only the access
token + its expiry — nothing else about the credential. SDK-default
credentials are excluded for the same reason as `get_credential`.
"""
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
credential = await creds_manager.get(user_id, cred_id)
if not credential:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if not provider_matches(credential.provider, provider):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if not isinstance(credential, OAuth2Credentials):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Picker tokens are only available for OAuth2 credentials",
)
if not credential.access_token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential has no access token; reconnect the account",
)
# Gate on provider+scope: only credentials that actually grant access to
# a provider-hosted picker flow may mint a token through this endpoint.
# Prevents using this path to extract bearer tokens for unrelated OAuth
# integrations (e.g. GitHub) that happen to be stored under the same user.
allowed_scopes = _PICKER_TOKEN_ALLOWED_SCOPES.get(provider)
if not allowed_scopes:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Picker tokens are not available for provider '{provider.value}'",
)
cred_scopes = set(credential.scopes or [])
if cred_scopes.isdisjoint(allowed_scopes):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Credential does not grant any scope eligible for the picker. "
"Reconnect with the appropriate scope."
),
)
return PickerTokenResponse(
access_token=credential.access_token.get_secret_value(),
access_token_expires_at=credential.access_token_expires_at,
)
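The provider and scope gate in `get_picker_token` comes down to two set checks. A minimal sketch, assuming plain strings instead of `ProviderName` and a hypothetical `picker_rejection_reason` helper that returns a reason string rather than raising `HTTPException`:

```python
from typing import Optional

# Hypothetical allowlist keyed by plain provider strings (an assumption;
# the original keys on the ProviderName enum).
PICKER_ALLOWED_SCOPES: dict[str, frozenset[str]] = {
    "google": frozenset(
        {
            "https://www.googleapis.com/auth/drive.file",
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive",
        }
    ),
}


def picker_rejection_reason(
    provider: str, scopes: Optional[list[str]]
) -> Optional[str]:
    """Return why a credential may not mint a picker token, or None if it may."""
    allowed = PICKER_ALLOWED_SCOPES.get(provider)
    if not allowed:
        return "provider not eligible"
    # At least one granted scope must be in the allowlist.
    if set(scopes or []).isdisjoint(allowed):
        return "no eligible scope"
    return None
```

A credential qualifies as soon as one of its scopes is allowlisted; it does not need to hold all of them.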
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
async def create_credentials(
user_id: Annotated[str, Security(get_user_id)],
@@ -892,186 +574,6 @@ async def _execute_webhook_preset_trigger(
# Continue processing - webhook should be resilient to individual failures
# -------------------- INCREMENTAL AUTH HELPERS -------------------- #
async def _prepare_scope_upgrade(
user_id: str,
provider: ProviderName,
credential_id: str,
requested_scopes: list[str],
) -> list[str]:
"""Validate an existing credential for scope upgrade and compute scopes.
For providers without native incremental auth (e.g. GitHub), returns the
union of existing + requested scopes. For providers that handle merging
server-side (e.g. Google with ``include_granted_scopes``), returns the
requested scopes unchanged.
Raises HTTPException on validation failure.
"""
# Platform-owned system credentials must never be upgraded — scope
# changes here would leak across every user that shares them.
if is_system_credential(credential_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="System credentials cannot be upgraded",
)
existing = await creds_manager.store.get_creds_by_id(user_id, credential_id)
if not existing:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credential to upgrade not found",
)
if not isinstance(existing, OAuth2Credentials):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only OAuth2 credentials can be upgraded",
)
if not provider_matches(existing.provider, provider.value):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential provider does not match the requested provider",
)
if existing.is_managed:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Managed credentials cannot be upgraded",
)
# Google handles scope merging via include_granted_scopes; others need
# the union of existing + new scopes in the login URL.
if provider != ProviderName.GOOGLE:
requested_scopes = list(set(requested_scopes) | set(existing.scopes))
return requested_scopes
async def _merge_or_create_credential(
user_id: str,
provider: ProviderName,
credentials: OAuth2Credentials,
credential_id: str | None,
) -> OAuth2Credentials:
"""Either upgrade an existing credential or create a new one.
When *credential_id* is set (explicit upgrade), merges scopes and updates
the existing credential. Otherwise, checks for an implicit merge (same
provider + username) before falling back to creating a new credential.
"""
if credential_id:
return await _upgrade_existing_credential(user_id, credential_id, credentials)
# Implicit merge: check for existing credential with same provider+username.
# Skip managed/system credentials and require a non-None username on both
# sides so we never accidentally merge unrelated credentials.
if credentials.username is None:
await creds_manager.create(user_id, credentials)
return credentials
existing_creds = await creds_manager.store.get_creds_by_provider(user_id, provider)
matching = next(
(
c
for c in existing_creds
if isinstance(c, OAuth2Credentials)
and not c.is_managed
and not is_system_credential(c.id)
and c.username is not None
and c.username == credentials.username
),
None,
)
if matching:
# Only merge into the existing credential when the new token
# already covers every scope we're about to advertise on it.
# Without this guard we'd overwrite ``matching.access_token`` with
# a narrower token while storing a wider ``scopes`` list — the
# record would claim authorizations the token does not grant, and
# blocks using the lost scopes would fail with opaque 401/403s
# until the user hits re-auth. On a narrowing login, keep the
# two credentials separate instead.
if set(credentials.scopes).issuperset(set(matching.scopes)):
return await _upgrade_existing_credential(user_id, matching.id, credentials)
await creds_manager.create(user_id, credentials)
return credentials
async def _upgrade_existing_credential(
user_id: str,
existing_cred_id: str,
new_credentials: OAuth2Credentials,
) -> OAuth2Credentials:
"""Merge scopes from *new_credentials* into an existing credential."""
# Defense-in-depth: re-check system and provider invariants right before
# the write. The login-time check in `_prepare_scope_upgrade` can go stale
# by the time the callback runs, and the implicit-merge path bypasses
# login-time validation entirely, so every write-path must enforce these
# on its own.
if is_system_credential(existing_cred_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="System credentials cannot be upgraded",
)
existing = await creds_manager.store.get_creds_by_id(user_id, existing_cred_id)
if not existing or not isinstance(existing, OAuth2Credentials):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential to upgrade not found",
)
if existing.is_managed:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Managed credentials cannot be upgraded",
)
if not provider_matches(existing.provider, new_credentials.provider):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential provider does not match the requested provider",
)
if (
existing.username
and new_credentials.username
and existing.username != new_credentials.username
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username mismatch: authenticated as a different user",
)
# Operate on a copy so the caller's ``new_credentials`` object is not
# mutated out from under them. Every caller today immediately discards
# or replaces its reference, but the implicit-merge path in
# ``_merge_or_create_credential`` reads ``credentials.scopes`` before
# calling into us — a future reader after the call would otherwise
# silently see the overwritten values.
merged = new_credentials.model_copy(deep=True)
merged.id = existing.id
merged.title = existing.title
merged.scopes = list(set(existing.scopes) | set(new_credentials.scopes))
merged.metadata = {
**(existing.metadata or {}),
**(new_credentials.metadata or {}),
}
# Preserve the existing refresh_token and username if the incremental
# response doesn't carry them. Providers like Google only return a
# refresh_token on first authorization — dropping it here would orphan
# the credential on the next access-token expiry, forcing the user to
# re-auth from scratch. Username is similarly sticky: if we've already
# resolved it for this credential, keep it rather than silently
# blanking it on an incremental upgrade.
if not merged.refresh_token and existing.refresh_token:
merged.refresh_token = existing.refresh_token
merged.refresh_token_expires_at = existing.refresh_token_expires_at
if not merged.username and existing.username:
merged.username = existing.username
await creds_manager.update(user_id, merged)
return merged
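The merge rules above (merge only when the new token covers the existing scopes, union the scopes, and keep a sticky refresh token and username) can be sketched with a small dataclass. `Cred`, `covers_existing`, and `merge` are hypothetical names for illustration; the real code operates on `OAuth2Credentials` and persists through `creds_manager`. The sketch also sorts the merged scopes for determinism, which the original does not do.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Cred:
    """Hypothetical, simplified stand-in for an OAuth2 credential."""

    scopes: tuple[str, ...]
    access_token: str
    refresh_token: Optional[str] = None
    username: Optional[str] = None


def covers_existing(new: Cred, existing: Cred) -> bool:
    """True when the new token grants every scope already recorded."""
    return set(new.scopes).issuperset(existing.scopes)


def merge(existing: Cred, new: Cred) -> Cred:
    """Build the merged credential without mutating either input."""
    return replace(
        new,
        scopes=tuple(sorted(set(existing.scopes) | set(new.scopes))),
        # Keep the stored values when the incremental response omits them.
        refresh_token=new.refresh_token or existing.refresh_token,
        username=new.username or existing.username,
    )
```

Because `Cred` is frozen and `replace` returns a new object, the caller's `new` value cannot be overwritten, which matches the intent of the `model_copy(deep=True)` call in the original.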
# --------------------------- UTILITIES ---------------------------- #
@@ -1282,21 +784,12 @@ def _get_provider_oauth_handler(
async def get_ayrshare_sso_url(
user_id: Annotated[str, Security(get_user_id)],
) -> AyrshareSSOResponse:
"""Generate a JWT SSO URL so the user can link their social accounts.
The per-user Ayrshare profile key is provisioned and persisted as a
standard ``is_managed=True`` credential by
:class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`.
This endpoint only signs a short-lived JWT pointing at the Ayrshare-
hosted social-linking page; all profile lifecycle logic lives with the
managed provider.
"""
if not ayrshare_settings_available():
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="Ayrshare integration is not configured",
)
Generate an SSO URL for Ayrshare social media integration.
Returns:
dict: Contains the SSO URL for Ayrshare integration
"""
try:
client = AyrshareClient()
except MissingConfigError:
@@ -1305,63 +798,66 @@ async def get_ayrshare_sso_url(
detail="Ayrshare integration is not configured",
)
# On-demand provisioning: AyrshareManagedProvider opts out of the
# credentials sweep (profile quota is per-user subscription-bound). This
# endpoint is the only trigger that provisions a profile — one Ayrshare
# profile per user who actually opens the connect flow, not one per
# every authenticated user.
provisioned = await ensure_managed_credential(
user_id, creds_manager.store, AyrshareManagedProvider()
)
if not provisioned:
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY,
detail="Failed to provision Ayrshare profile",
)
# Ayrshare profile key is stored in the credentials store
# It is generated when creating a new profile, if there is no profile key,
# we create a new profile and store the profile key in the credentials store
ayrshare_creds = [
c
for c in await creds_manager.store.get_creds_by_provider(user_id, "ayrshare")
if c.is_managed and isinstance(c, APIKeyCredentials)
]
if not ayrshare_creds:
logger.error(
"Ayrshare credential provisioning did not produce a credential "
"for user %s",
user_id,
)
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY,
detail="Failed to provision Ayrshare profile",
)
profile_key_str = ayrshare_creds[0].api_key.get_secret_value()
user_integrations: UserIntegrations = await get_user_integrations(user_id)
profile_key = user_integrations.managed_credentials.ayrshare_profile_key
if not profile_key:
logger.debug(f"Creating new Ayrshare profile for user {user_id}")
try:
profile = await client.create_profile(
title=f"User {user_id}", messaging_active=True
)
profile_key = profile.profileKey
await creds_manager.store.set_ayrshare_profile_key(user_id, profile_key)
except Exception as e:
logger.error(f"Error creating Ayrshare profile for user {user_id}: {e}")
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY,
detail="Failed to create Ayrshare profile",
)
else:
logger.debug(f"Using existing Ayrshare profile for user {user_id}")
profile_key_str = (
profile_key.get_secret_value()
if isinstance(profile_key, SecretStr)
else str(profile_key)
)
private_key = settings.secrets.ayrshare_jwt_key
# Ayrshare JWT max lifetime is 2880 minutes (48 h).
# Ayrshare JWT expiry is 2880 minutes (48 hours)
max_expiry_minutes = 2880
try:
logger.debug(f"Generating Ayrshare JWT for user {user_id}")
jwt_response = await client.generate_jwt(
private_key=private_key,
profile_key=profile_key_str,
# `allowed_social` is the set of networks the Ayrshare-hosted
# social-linking page will *offer* the user to connect. Blocks
# exist for more platforms than are listed here; the list is
# deliberately narrower so the rollout can verify each network
# end-to-end before widening the user-visible surface. Keep
# in sync with tested platforms — extend as each is verified
# against the block + Ayrshare's network-specific quirks.
allowed_social=[
# NOTE: We are enabling platforms one at a time
# to speed up the development process
# SocialPlatform.FACEBOOK,
SocialPlatform.TWITTER,
SocialPlatform.LINKEDIN,
SocialPlatform.INSTAGRAM,
SocialPlatform.YOUTUBE,
# SocialPlatform.REDDIT,
# SocialPlatform.TELEGRAM,
# SocialPlatform.GOOGLE_MY_BUSINESS,
# SocialPlatform.PINTEREST,
SocialPlatform.TIKTOK,
# SocialPlatform.BLUESKY,
# SocialPlatform.SNAPCHAT,
# SocialPlatform.THREADS,
],
expires_in=max_expiry_minutes,
verify=True,
)
except Exception as exc:
logger.error("Error generating Ayrshare JWT for user %s: %s", user_id, exc)
except Exception as e:
logger.error(f"Error generating Ayrshare JWT for user {user_id}: {e}")
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT"
)
@@ -1371,37 +867,20 @@ async def get_ayrshare_sso_url(
# === PROVIDER DISCOVERY ENDPOINTS ===
@router.get("/providers", response_model=List[ProviderMetadata])
async def list_providers() -> List[ProviderMetadata]:
@router.get("/providers", response_model=List[str])
async def list_providers() -> List[str]:
"""
Get metadata for every available provider.
Get a list of all available provider names.
Returns both statically defined providers (from ``ProviderName`` enum) and
dynamically registered providers (from SDK decorators). Each entry includes
a ``description`` declared via ``ProviderBuilder.with_description(...)`` in
the provider's ``_config.py``.
Returns both statically defined providers (from ProviderName enum)
and dynamically registered providers (from SDK decorators).
Note: The complete list of provider names is also available as a constant
in the generated TypeScript client via PROVIDER_NAMES.
"""
# Ensure all block modules (and therefore every provider's _config.py) are
# imported before we read from AutoRegistry. Cached on first call.
try:
from backend.blocks import load_all_blocks
load_all_blocks()
except Exception as e:
logger.warning(f"Failed to load blocks for provider metadata: {e}")
# Get all providers at runtime
all_providers = get_all_provider_names()
return [
ProviderMetadata(
name=name,
description=get_provider_description(name),
supported_auth_types=get_supported_auth_types(name),
)
for name in all_providers
]
return all_providers
@router.get("/providers/system", response_model=List[str])

@@ -393,7 +393,7 @@ class TestEnsureManagedCredentials:
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1", store)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
@@ -568,181 +568,3 @@ class TestCleanupManagedCredentials:
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.
class TestGetPickerToken:
"""POST /{provider}/credentials/{cred_id}/picker-token must:
1. Return the access token for OAuth2 creds the caller owns.
2. 404 for non-owned, non-existent, or wrong-provider creds.
3. 400 for non-OAuth2 creds (API key, host-scoped, user/password).
4. 404 for SDK default creds (same hardening as get_credential).
5. Preserve the `TestGetCredentialReturnsMetaOnly` contract — the
existing meta-only endpoint must still strip secrets even after
this picker-token endpoint exists."""
def test_oauth2_owner_gets_access_token(self):
# Use a Google cred with a drive.file scope — only picker-eligible
# (provider, scope) pairs can mint a token. GitHub-style creds are
# explicitly rejected; see `test_non_picker_provider_rejected_as_400`.
cred = _make_oauth2_cred(
cred_id="cred-gdrive",
provider="google",
)
cred.scopes = ["https://www.googleapis.com/auth/drive.file"]
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/google/credentials/cred-gdrive/picker-token")
assert resp.status_code == 200
data = resp.json()
# The whole point of this endpoint: the access token IS returned here.
assert data["access_token"] == "ghp_secret_token"
# Only the two declared fields come back — nothing else leaks.
assert set(data.keys()) <= {"access_token", "access_token_expires_at"}
def test_non_picker_provider_rejected_as_400(self):
"""Provider allowlist: even with a valid OAuth2 credential, a
non-picker provider (GitHub, etc.) cannot mint a picker token.
Stops this endpoint from being used as a generic bearer-token
extraction path for any stored OAuth cred under the same user."""
cred = _make_oauth2_cred(provider="github")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/github/credentials/cred-456/picker-token")
assert resp.status_code == 400
assert "not available for provider" in resp.json()["detail"]
assert "ghp_secret_token" not in str(resp.json())
def test_google_oauth_without_drive_scope_rejected(self):
"""Scope allowlist: a Google OAuth2 cred that only carries non-picker
scopes (e.g. gmail.readonly, calendar) cannot mint a picker token.
Forces the frontend to reconnect with a Drive scope before the
picker is available."""
cred = _make_oauth2_cred(provider="google")
cred.scopes = [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/calendar",
]
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/google/credentials/cred-456/picker-token")
assert resp.status_code == 400
assert "picker" in resp.json()["detail"].lower()
def test_api_key_credential_rejected_as_400(self):
cred = _make_api_key_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/openai/credentials/cred-123/picker-token")
assert resp.status_code == 400
# API keys must not silently fall through to a 200 response of some
# other shape — the client should see a clear shape rejection.
body = str(resp.json())
assert "sk-secret-key-value" not in body
def test_user_password_credential_rejected_as_400(self):
cred = _make_user_password_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/openai/credentials/cred-789/picker-token")
assert resp.status_code == 400
body = str(resp.json())
assert "s3cret-pass" not in body
assert "admin" not in body
def test_host_scoped_credential_rejected_as_400(self):
cred = _make_host_scoped_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/openai/credentials/cred-host/picker-token")
assert resp.status_code == 400
assert "top-secret" not in str(resp.json())
def test_missing_credential_returns_404(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=None)
resp = client.post("/github/credentials/nonexistent/picker-token")
assert resp.status_code == 404
assert resp.json()["detail"] == "Credentials not found"
def test_wrong_provider_returns_404(self):
"""Symmetric with get_credential: provider mismatch is a generic
404, not a 400, so we don't leak existence of a credential the
caller doesn't own on that provider."""
cred = _make_oauth2_cred(provider="github")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/google/credentials/cred-456/picker-token")
assert resp.status_code == 404
assert resp.json()["detail"] == "Credentials not found"
def test_sdk_default_returns_404(self):
"""SDK defaults are invisible to the user-facing API — picker-token
must not mint a token for them either."""
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock()
resp = client.post("/openai/credentials/openai-default/picker-token")
assert resp.status_code == 404
mock_mgr.get.assert_not_called()
def test_oauth2_without_access_token_returns_400(self):
"""A stored OAuth2 cred whose access_token is missing can't satisfy
a picker init. Surface a clear reconnect instruction rather than
returning an empty string."""
cred = _make_oauth2_cred()
# Simulate a cred that lost its access token
object.__setattr__(cred, "access_token", None)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/github/credentials/cred-456/picker-token")
assert resp.status_code == 400
assert "reconnect" in resp.json()["detail"].lower()
def test_meta_only_endpoint_still_strips_access_token(self):
"""Regression guard for the coexistence contract: the new
picker-token endpoint must NOT accidentally leak the token through
the meta-only GET endpoint. TestGetCredentialReturnsMetaOnly
covers this more broadly; this is a fast sanity check co-located
with the new endpoint's tests."""
cred = _make_oauth2_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/github/credentials/cred-456")
assert resp.status_code == 200
body = resp.json()
assert "access_token" not in body
assert "refresh_token" not in body
assert "ghp_secret_token" not in str(body)

@@ -12,7 +12,6 @@ import prisma.models
import backend.api.features.library.model as library_model
import backend.data.graph as graph_db
from backend.api.features.library.db import _fetch_schedule_info
from backend.data.graph import GraphModel, GraphSettings
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
@@ -118,5 +117,4 @@ async def add_graph_to_library(
f"for store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)
return library_model.LibraryAgent.from_db(added_agent)

@@ -21,17 +21,13 @@ async def test_add_graph_to_library_create_new_agent() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
mock_from_db.assert_called_once_with(created_agent)
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
@@ -58,10 +54,6 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
@@ -73,7 +65,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
mock_from_db.assert_called_once_with(updated_agent)
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {

@@ -1,7 +1,6 @@
import asyncio
import itertools
import logging
from datetime import datetime, timezone
from typing import Literal, Optional
import fastapi
@@ -44,65 +43,6 @@ config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
"""Fetch execution counts per graph in a single batched query."""
if not graph_ids:
return {}
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
by=["agentGraphId"],
where={
"userId": user_id,
"agentGraphId": {"in": graph_ids},
"isDeleted": False,
},
count=True,
)
return {
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
for row in rows
}
async def _fetch_schedule_info(
user_id: str, graph_id: Optional[str] = None
) -> dict[str, str]:
"""Fetch a map of graph_id → earliest next_run_time ISO string.
When `graph_id` is provided, the scheduler query is narrowed to that graph,
which is cheaper for single-agent lookups (detail page, post-update, etc.).
"""
try:
scheduler_client = get_scheduler_client()
schedules = await scheduler_client.get_execution_schedules(
graph_id=graph_id,
user_id=user_id,
)
earliest: dict[str, tuple[datetime, str]] = {}
for s in schedules:
parsed = _parse_iso_datetime(s.next_run_time)
if parsed is None:
continue
current = earliest.get(s.graph_id)
if current is None or parsed < current[0]:
earliest[s.graph_id] = (parsed, s.next_run_time)
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
except Exception:
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
return {}
def _parse_iso_datetime(value: str) -> Optional[datetime]:
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
try:
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
logger.warning("Failed to parse schedule next_run_time: %s", value)
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed
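The earliest-next-run selection in `_fetch_schedule_info` can be exercised without a scheduler client. The sketch below assumes schedules arrive as `(graph_id, next_run_time)` string pairs, which is a simplification of the schedule objects in the original; `parse_iso` and `earliest_next_runs` are hypothetical names.

```python
from datetime import datetime, timezone
from typing import Optional


def parse_iso(value: str) -> Optional[datetime]:
    """Parse ISO 8601, accepting a trailing Z and treating naive values as UTC."""
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def earliest_next_runs(schedules: list[tuple[str, str]]) -> dict[str, str]:
    """Map each graph_id to the original ISO string of its earliest run."""
    earliest: dict[str, tuple[datetime, str]] = {}
    for graph_id, next_run in schedules:
        parsed = parse_iso(next_run)
        if parsed is None:
            continue  # Skip unparseable timestamps instead of failing the batch.
        current = earliest.get(graph_id)
        if current is None or parsed < current[0]:
            earliest[graph_id] = (parsed, next_run)
    return {gid: iso for gid, (_, iso) in earliest.items()}
```

Comparing parsed, timezone-aware values rather than raw strings is what makes mixed `Z`, offset, and naive inputs order correctly.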
async def list_library_agents(
user_id: str,
search_term: Optional[str] = None,
@@ -197,22 +137,12 @@ async def list_library_agents(
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
library_agent = library_model.LibraryAgent.from_db(agent)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -284,22 +214,12 @@ async def list_favorite_library_agents(
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
)
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
library_agent = library_model.LibraryAgent.from_db(agent)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -365,12 +285,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
where={"userId": store_listing.owningUserId}
)
schedule_info = (
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
if library_agent.AgentGraph
else {}
)
return library_model.LibraryAgent.from_db(
library_agent,
sub_graphs=(
@@ -380,7 +294,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
),
store_listing=store_listing,
profile=profile,
schedule_info=schedule_info,
)
@@ -416,10 +329,7 @@ async def get_library_agent_by_store_version_id(
},
include=library_agent_include(user_id),
)
if not agent:
return None
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
return library_model.LibraryAgent.from_db(agent) if agent else None
async def get_library_agent_by_graph_id(
@@ -448,10 +358,7 @@ async def get_library_agent_by_graph_id(
assert agent.AgentGraph # make type checker happy
# Include sub-graphs so we can make a full credentials input schema
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
)
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
async def add_generated_agent_image(
@@ -593,11 +500,7 @@ async def create_library_agent(
for agent, graph in zip(library_agents, graph_entries):
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in library_agents
]
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
async def update_agent_version_in_library(
@@ -659,8 +562,7 @@ async def update_agent_version_in_library(
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
return library_model.LibraryAgent.from_db(lib)
async def create_graph_in_library(
@@ -743,7 +645,6 @@ async def update_library_agent_version_and_settings(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
builder_chat_session_id=library.settings.builder_chat_session_id,
)
if updated_settings != library.settings:
library = await update_library_agent(
@@ -1566,11 +1467,7 @@ async def bulk_move_agents_to_folder(
),
)
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in agents
]
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
def collect_tree_ids(
@@ -1804,7 +1701,7 @@ async def create_preset_from_graph_execution(
raise NotFoundError(
f"Graph #{graph_execution.graph_id} not found or accessible"
)
elif len(graph.regular_credentials_inputs) > 0:
elif len(graph.aggregate_credentials_inputs()) > 0:
raise ValueError(
f"Graph execution #{graph_exec_id} can't be turned into a preset "
"because it was run before this feature existed "


@@ -65,11 +65,6 @@ async def test_get_library_agents(mocker):
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
# Call function
result = await db.list_library_agents("test-user")
@@ -358,136 +353,3 @@ async def test_create_library_agent_uses_upsert():
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False
@pytest.mark.asyncio
async def test_list_favorite_library_agents(mocker):
mock_library_agents = [
prisma.models.LibraryAgent(
id="fav1",
userId="test-user",
agentGraphId="agent-fav",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=True,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-fav",
version=1,
name="Favorite Agent",
description="My Favorite",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
)
result = await db.list_favorite_library_agents("test-user")
assert len(result.agents) == 1
assert result.agents[0].id == "fav1"
assert result.agents[0].name == "Favorite Agent"
assert result.agents[0].graph_id == "agent-fav"
assert result.pagination.total_items == 1
assert result.pagination.total_pages == 1
assert result.pagination.current_page == 1
assert result.pagination.page_size == 50
@pytest.mark.asyncio
async def test_list_library_agents_skips_failed_agent(mocker):
"""Agents that fail parsing should be skipped — covers the except branch."""
mock_library_agents = [
prisma.models.LibraryAgent(
id="ua-bad",
userId="test-user",
agentGraphId="agent-bad",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-bad",
version=1,
name="Bad Agent",
description="",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
mocker.patch(
"backend.api.features.library.model.LibraryAgent.from_db",
side_effect=Exception("parse error"),
)
result = await db.list_library_agents("test-user")
assert len(result.agents) == 0
assert result.pagination.total_items == 1
@pytest.mark.asyncio
async def test_fetch_execution_counts_empty_graph_ids():
result = await db._fetch_execution_counts("user-1", [])
assert result == {}
@pytest.mark.asyncio
async def test_fetch_execution_counts_uses_group_by(mocker):
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
mock_prisma.return_value.group_by = mocker.AsyncMock(
return_value=[
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
]
)
result = await db._fetch_execution_counts(
"user-1", ["graph-1", "graph-2", "graph-3"]
)
assert result == {"graph-1": 5, "graph-2": 2}
mock_prisma.return_value.group_by.assert_called_once_with(
by=["agentGraphId"],
where={
"userId": "user-1",
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
"isDeleted": False,
},
count=True,
)
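
Reviewer note: the removed `test_fetch_execution_counts_uses_group_by` test above pins down both the shape of the grouped rows and the expected result. A minimal sketch of that row-to-dict step follows, assuming rows shaped like the mocked `group_by` return value. `counts_from_group_rows` is a hypothetical name used for illustration and is not the helper this diff removes.

```python
from typing import Any


def counts_from_group_rows(rows: list[dict[str, Any]]) -> dict[str, int]:
    """Collapse grouped rows into {graph_id: execution_count}.

    Assumes each row looks like the mocked value in the test:
    {"agentGraphId": "...", "_count": {"_all": N}}.
    """
    return {row["agentGraphId"]: int(row["_count"]["_all"]) for row in rows}


rows = [
    {"agentGraphId": "graph-1", "_count": {"_all": 5}},
    {"agentGraphId": "graph-2", "_count": {"_all": 2}},
]
counts = counts_from_group_rows(rows)

# Graphs without executions produce no row, so callers look them up with .get().
missing = counts.get("graph-3")
```

A graph with no executions is absent from the mapping rather than present with a zero, which matches the `execution_counts.get(agent.agentGraphId)` lookups removed elsewhere in this diff.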


@@ -214,14 +214,6 @@ class LibraryAgent(pydantic.BaseModel):
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
is_scheduled: bool = pydantic.Field(
default=False,
description="Whether this agent has active execution schedules",
)
next_scheduled_run: str | None = pydantic.Field(
default=None,
description="ISO 8601 timestamp of the next scheduled run, if any",
)
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -231,8 +223,6 @@ class LibraryAgent(pydantic.BaseModel):
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
store_listing: Optional[prisma.models.StoreListing] = None,
profile: Optional[prisma.models.Profile] = None,
execution_count_override: Optional[int] = None,
schedule_info: Optional[dict[str, str]] = None,
) -> "LibraryAgent":
"""
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
@@ -268,14 +258,10 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = (
execution_count_override
if execution_count_override is not None
else len(executions)
)
execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if executions and execution_count > 0:
if execution_count > 0:
success_count = sum(
1
for e in executions
@@ -368,10 +354,6 @@ class LibraryAgent(pydantic.BaseModel):
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
next_scheduled_run=(
schedule_info.get(agent.agentGraphId) if schedule_info else None
),
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
)
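
Reviewer note: with `execution_count_override` gone, `execution_count` is always `len(executions)`, so the single `execution_count > 0` guard is enough to avoid dividing by zero. The sketch below illustrates that reasoning under two assumptions: statuses are plain strings, and only `"COMPLETED"` counts as a success (the removed model test expects 100.0 for one completed run). `success_rate` is a hypothetical standalone function, not the code inside `from_db`.

```python
from typing import Optional


def success_rate(statuses: list[str]) -> Optional[float]:
    """Percentage of runs that completed, or None when there are no runs.

    Assumption: "COMPLETED" is the only status counted as a success.
    """
    execution_count = len(statuses)
    if execution_count == 0:
        # No runs: leave the rate undefined instead of dividing by zero.
        return None
    success_count = sum(1 for status in statuses if status == "COMPLETED")
    return success_count / execution_count * 100


one_completed = success_rate(["COMPLETED"])
mixed = success_rate(["COMPLETED", "FAILED"])
no_runs = success_rate([])
```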


@@ -1,66 +1,11 @@
import datetime
import prisma.enums
import prisma.models
import pytest
from . import model as library_model
def _make_library_agent(
*,
graph_id: str = "g1",
executions: list | None = None,
) -> prisma.models.LibraryAgent:
return prisma.models.LibraryAgent(
id="la1",
userId="u1",
agentGraphId=graph_id,
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=True,
isDeleted=False,
isArchived=False,
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id=graph_id,
version=1,
name="Agent",
description="Desc",
userId="u1",
isActive=True,
createdAt=datetime.datetime.now(),
Executions=executions,
),
)
def test_from_db_execution_count_override_covers_success_rate():
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
now = datetime.datetime.now(datetime.timezone.utc)
exec1 = prisma.models.AgentGraphExecution(
id="exec-1",
agentGraphId="g1",
agentGraphVersion=1,
userId="u1",
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
createdAt=now,
updatedAt=now,
isDeleted=False,
isShared=False,
)
agent = _make_library_agent(executions=[exec1])
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
assert result.execution_count == 1
assert result.success_rate is not None
assert result.success_rate == 100.0
@pytest.mark.asyncio
async def test_agent_preset_from_db(test_user_id: str):
# Create mock DB agent


@@ -1 +0,0 @@
"""Platform bot linking — user-facing REST routes."""


@@ -1,158 +0,0 @@
"""User-facing platform_linking REST routes (JWT auth)."""
import logging
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Path, Security
from backend.data.db_accessors import platform_linking_db
from backend.platform_linking.models import (
ConfirmLinkResponse,
ConfirmUserLinkResponse,
DeleteLinkResponse,
LinkTokenInfoResponse,
PlatformLinkInfo,
PlatformUserLinkInfo,
)
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
logger = logging.getLogger(__name__)
router = APIRouter()
TokenPath = Annotated[
str,
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
]
def _translate(exc: Exception) -> HTTPException:
if isinstance(exc, NotFoundError):
return HTTPException(status_code=404, detail=str(exc))
if isinstance(exc, NotAuthorizedError):
return HTTPException(status_code=403, detail=str(exc))
if isinstance(exc, LinkAlreadyExistsError):
return HTTPException(status_code=409, detail=str(exc))
if isinstance(exc, LinkTokenExpiredError):
return HTTPException(status_code=410, detail=str(exc))
if isinstance(exc, LinkFlowMismatchError):
return HTTPException(status_code=400, detail=str(exc))
return HTTPException(status_code=500, detail="Internal error.")
@router.get(
"/tokens/{token}/info",
response_model=LinkTokenInfoResponse,
dependencies=[Security(auth.requires_user)],
summary="Get display info for a link token",
)
async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
try:
return await platform_linking_db().get_link_token_info(token)
except (NotFoundError, LinkTokenExpiredError) as exc:
raise _translate(exc) from exc
@router.post(
"/tokens/{token}/confirm",
response_model=ConfirmLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a SERVER link token (user must be authenticated)",
)
async def confirm_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmLinkResponse:
try:
return await platform_linking_db().confirm_server_link(token, user_id)
except (
NotFoundError,
LinkFlowMismatchError,
LinkTokenExpiredError,
LinkAlreadyExistsError,
) as exc:
raise _translate(exc) from exc
@router.post(
"/user-tokens/{token}/confirm",
response_model=ConfirmUserLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a USER link token (user must be authenticated)",
)
async def confirm_user_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmUserLinkResponse:
try:
return await platform_linking_db().confirm_user_link(token, user_id)
except (
NotFoundError,
LinkFlowMismatchError,
LinkTokenExpiredError,
LinkAlreadyExistsError,
) as exc:
raise _translate(exc) from exc
@router.get(
"/links",
response_model=list[PlatformLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all platform servers linked to the authenticated user",
)
async def list_my_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformLinkInfo]:
return await platform_linking_db().list_server_links(user_id)
@router.get(
"/user-links",
response_model=list[PlatformUserLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all DM links for the authenticated user",
)
async def list_my_user_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformUserLinkInfo]:
return await platform_linking_db().list_user_links(user_id)
@router.delete(
"/links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a platform server",
)
async def delete_link(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
try:
return await platform_linking_db().delete_server_link(link_id, user_id)
except (NotFoundError, NotAuthorizedError) as exc:
raise _translate(exc) from exc
@router.delete(
"/user-links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a DM / user link",
)
async def delete_user_link_route(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
try:
return await platform_linking_db().delete_user_link(link_id, user_id)
except (NotFoundError, NotAuthorizedError) as exc:
raise _translate(exc) from exc
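
Reviewer note: the `_translate` helper in the removed routes file maps each domain exception to one HTTP status through a chain of `isinstance` checks. The same mapping can be expressed as an ordered table. The sketch below uses stand-in exception classes so it runs without the backend package; the real classes live in `backend.util.exceptions`, and `status_for` is a hypothetical name.

```python
# Stand-in exception types for illustration only.
class NotFoundError(Exception): ...
class NotAuthorizedError(Exception): ...
class LinkAlreadyExistsError(Exception): ...
class LinkTokenExpiredError(Exception): ...
class LinkFlowMismatchError(Exception): ...


# Checked in order, mirroring the isinstance chain in _translate.
_STATUS_BY_EXCEPTION: tuple[tuple[type[Exception], int], ...] = (
    (NotFoundError, 404),
    (NotAuthorizedError, 403),
    (LinkAlreadyExistsError, 409),
    (LinkTokenExpiredError, 410),
    (LinkFlowMismatchError, 400),
)


def status_for(exc: Exception) -> int:
    """Return the HTTP status for a domain exception, defaulting to 500."""
    for exc_type, status in _STATUS_BY_EXCEPTION:
        if isinstance(exc, exc_type):
            return status
    return 500


expired_status = status_for(LinkTokenExpiredError("expired"))
unknown_status = status_for(RuntimeError("boom"))
```

Keeping the table ordered matters only if the real exception classes share a base class, in which case the more specific type should come first.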


@@ -1,264 +0,0 @@
"""Route tests: domain exceptions → HTTPException status codes."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
def _db_mock(**method_configs):
"""Return a mock of the accessor's return value with the given AsyncMocks."""
db = MagicMock()
for name, mock in method_configs.items():
setattr(db, name, mock)
return db
class TestTokenInfoRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import (
get_link_token_info_route,
)
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await get_link_token_info_route(token="abc")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_expired_maps_to_410(self):
from backend.api.features.platform_linking.routes import (
get_link_token_info_route,
)
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await get_link_token_info_route(token="abc")
assert exc.value.status_code == 410
class TestConfirmLinkRouteTranslation:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"exc,expected_status",
[
(NotFoundError("missing"), 404),
(LinkFlowMismatchError("wrong flow"), 400),
(LinkTokenExpiredError("expired"), 410),
(LinkAlreadyExistsError("already"), 409),
],
)
async def test_translation(self, exc: Exception, expected_status: int):
from backend.api.features.platform_linking.routes import confirm_link_token
db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as ctx:
await confirm_link_token(token="abc", user_id="u1")
assert ctx.value.status_code == expected_status
class TestConfirmUserLinkRouteTranslation:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"exc,expected_status",
[
(NotFoundError("missing"), 404),
(LinkFlowMismatchError("wrong flow"), 400),
(LinkTokenExpiredError("expired"), 410),
(LinkAlreadyExistsError("already"), 409),
],
)
async def test_translation(self, exc: Exception, expected_status: int):
from backend.api.features.platform_linking.routes import confirm_user_link_token
db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as ctx:
await confirm_user_link_token(token="abc", user_id="u1")
assert ctx.value.status_code == expected_status
class TestDeleteLinkRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import delete_link
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_link(link_id="x", user_id="u1")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_not_owned_maps_to_403(self):
from backend.api.features.platform_linking.routes import delete_link
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_link(link_id="x", user_id="u1")
assert exc.value.status_code == 403
class TestDeleteUserLinkRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import delete_user_link_route
db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_user_link_route(link_id="x", user_id="u1")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_not_owned_maps_to_403(self):
from backend.api.features.platform_linking.routes import delete_user_link_route
db = _db_mock(
delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_user_link_route(link_id="x", user_id="u1")
assert exc.value.status_code == 403
# ── Adversarial: malformed token path params ──────────────────────────
class TestAdversarialTokenPath:
# TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.
@pytest.fixture
def client(self):
import fastapi
from autogpt_libs.auth import get_user_id, requires_user
from fastapi.testclient import TestClient
import backend.api.features.platform_linking.routes as routes_mod
app = fastapi.FastAPI()
app.dependency_overrides[requires_user] = lambda: None
app.dependency_overrides[get_user_id] = lambda: "caller-user"
app.include_router(routes_mod.router, prefix="/api/platform-linking")
return TestClient(app)
def test_rejects_token_with_special_chars(self, client):
response = client.get("/api/platform-linking/tokens/bad%24token/info")
assert response.status_code == 422
def test_rejects_token_with_path_traversal(self, client):
for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
response = client.get(f"/api/platform-linking/tokens/{probe}/info")
assert response.status_code in (
404,
422,
), f"path-traversal probe {probe!r} returned {response.status_code}"
def test_rejects_token_too_long(self, client):
long_token = "a" * 65
response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
assert response.status_code == 422
def test_accepts_token_at_max_length(self, client):
token = "a" * 64
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
response = client.get(f"/api/platform-linking/tokens/{token}/info")
assert response.status_code == 404
def test_accepts_urlsafe_b64_token_shape(self, client):
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
assert response.status_code == 404
def test_confirm_rejects_malformed_token(self, client):
response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
assert response.status_code == 422
class TestAdversarialDeleteLinkId:
"""DELETE link_id has no regex — ensure weird values are handled via
NotFoundError (no crash, no cross-user leak)."""
@pytest.fixture
def client(self):
import fastapi
from autogpt_libs.auth import get_user_id, requires_user
from fastapi.testclient import TestClient
import backend.api.features.platform_linking.routes as routes_mod
app = fastapi.FastAPI()
app.dependency_overrides[requires_user] = lambda: None
app.dependency_overrides[get_user_id] = lambda: "caller-user"
app.include_router(routes_mod.router, prefix="/api/platform-linking")
return TestClient(app)
def test_weird_link_id_returns_404(self, client):
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
response = client.delete(f"/api/platform-linking/links/{link_id}")
assert response.status_code in (404, 405)
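
Reviewer note: the `TokenPath` constraint exercised by the adversarial tests above (pattern `^[A-Za-z0-9_-]+$`, `max_length=64`) can be checked outside FastAPI with the standard `re` module. This is a sketch of the same rule, not the framework's validation path; `is_valid_token` is a hypothetical helper. It uses `fullmatch` so a trailing newline cannot slip past the `$` anchor.

```python
import re

_TOKEN_RE = re.compile(r"[A-Za-z0-9_-]+")
_MAX_TOKEN_LENGTH = 64


def is_valid_token(token: str) -> bool:
    """True when the token is 1-64 URL-safe base64 characters."""
    return (
        len(token) <= _MAX_TOKEN_LENGTH
        and _TOKEN_RE.fullmatch(token) is not None
    )


accepted = [t for t in ("abc-_XYZ123-_abc", "a" * 64) if is_valid_token(t)]
rejected = [t for t in ("bad$token", "foo..bar", "a" * 65, "") if not is_valid_token(t)]
```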


@@ -1,20 +0,0 @@
import pydantic
class PushSubscriptionKeys(pydantic.BaseModel):
p256dh: str = pydantic.Field(min_length=1, max_length=512)
auth: str = pydantic.Field(min_length=1, max_length=512)
class PushSubscribeRequest(pydantic.BaseModel):
endpoint: str = pydantic.Field(min_length=1, max_length=2048)
keys: PushSubscriptionKeys
user_agent: str | None = pydantic.Field(default=None, max_length=512)
class PushUnsubscribeRequest(pydantic.BaseModel):
endpoint: str = pydantic.Field(min_length=1, max_length=2048)
class VapidPublicKeyResponse(pydantic.BaseModel):
public_key: str


@@ -1,64 +0,0 @@
from typing import Annotated
from autogpt_libs.auth import get_user_id, requires_user
from fastapi import APIRouter, HTTPException, Security
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from backend.api.features.push.model import (
PushSubscribeRequest,
PushUnsubscribeRequest,
VapidPublicKeyResponse,
)
from backend.data.push_subscription import (
delete_push_subscription,
upsert_push_subscription,
validate_push_endpoint,
)
from backend.util.settings import Settings
router = APIRouter()
_settings = Settings()
@router.get(
"/vapid-key",
summary="Get VAPID public key for push subscription",
)
async def get_vapid_public_key() -> VapidPublicKeyResponse:
return VapidPublicKeyResponse(public_key=_settings.secrets.vapid_public_key)
@router.post(
"/subscribe",
summary="Register a push subscription for the current user",
status_code=HTTP_204_NO_CONTENT,
dependencies=[Security(requires_user)],
)
async def subscribe_push(
user_id: Annotated[str, Security(get_user_id)],
body: PushSubscribeRequest,
) -> None:
try:
await validate_push_endpoint(body.endpoint)
await upsert_push_subscription(
user_id=user_id,
endpoint=body.endpoint,
p256dh=body.keys.p256dh,
auth=body.keys.auth,
user_agent=body.user_agent,
)
except ValueError as e:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=str(e))
@router.post(
"/unsubscribe",
summary="Remove a push subscription",
status_code=HTTP_204_NO_CONTENT,
dependencies=[Security(requires_user)],
)
async def unsubscribe_push(
user_id: Annotated[str, Security(get_user_id)],
body: PushUnsubscribeRequest,
) -> None:
await delete_push_subscription(user_id, body.endpoint)


@@ -1,240 +0,0 @@
"""Tests for push notification routes."""
from unittest.mock import AsyncMock, MagicMock
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.push.routes import router
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_get_vapid_public_key(mocker):
mock_settings = MagicMock()
mock_settings.secrets.vapid_public_key = "test-vapid-public-key-base64url"
mocker.patch(
"backend.api.features.push.routes._settings",
mock_settings,
)
response = client.get("/vapid-key")
assert response.status_code == 200
data = response.json()
assert data["public_key"] == "test-vapid-public-key-base64url"
def test_get_vapid_public_key_empty(mocker):
mock_settings = MagicMock()
mock_settings.secrets.vapid_public_key = ""
mocker.patch(
"backend.api.features.push.routes._settings",
mock_settings,
)
response = client.get("/vapid-key")
assert response.status_code == 200
data = response.json()
assert data["public_key"] == ""
def test_subscribe_push(mocker, test_user_id):
mock_upsert = mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
"user_agent": "Mozilla/5.0 Test",
},
)
assert response.status_code == 204
mock_upsert.assert_awaited_once_with(
user_id=test_user_id,
endpoint="https://fcm.googleapis.com/fcm/send/abc123",
p256dh="test-p256dh-key",
auth="test-auth-key",
user_agent="Mozilla/5.0 Test",
)
def test_subscribe_push_without_user_agent(mocker, test_user_id):
mock_upsert = mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 204
mock_upsert.assert_awaited_once_with(
user_id=test_user_id,
endpoint="https://fcm.googleapis.com/fcm/send/abc123",
p256dh="test-p256dh-key",
auth="test-auth-key",
user_agent=None,
)
def test_subscribe_push_missing_keys():
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
},
)
assert response.status_code == 422
def test_subscribe_push_missing_endpoint():
response = client.post(
"/subscribe",
json={
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 422
def test_subscribe_push_rejects_empty_crypto_keys():
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {"p256dh": "", "auth": ""},
},
)
assert response.status_code == 422
def test_subscribe_push_rejects_oversized_endpoint():
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/" + "x" * 3000,
"keys": {"p256dh": "k", "auth": "a"},
},
)
assert response.status_code == 422
def test_unsubscribe_push(mocker, test_user_id):
mock_delete = mocker.patch(
"backend.api.features.push.routes.delete_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/unsubscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
},
)
assert response.status_code == 204
mock_delete.assert_awaited_once_with(
test_user_id,
"https://fcm.googleapis.com/fcm/send/abc123",
)
def test_unsubscribe_push_missing_endpoint():
response = client.post(
"/unsubscribe",
json={},
)
assert response.status_code == 422
@pytest.mark.parametrize(
"untrusted_endpoint",
[
"https://localhost/evil",
"https://127.0.0.1/evil",
"https://169.254.169.254/latest/meta-data/",
"https://internal-service.local/api",
"https://attacker.example.com/push",
"http://fcm.googleapis.com/fcm/send/abc",
"file:///etc/passwd",
],
)
def test_subscribe_push_rejects_untrusted_endpoints(mocker, untrusted_endpoint):
mock_upsert = mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/subscribe",
json={
"endpoint": untrusted_endpoint,
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 400
mock_upsert.assert_not_awaited()
def test_subscribe_push_surfaces_cap_as_400(mocker):
mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
side_effect=ValueError("Subscription limit of 20 per user reached"),
)
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 400
assert "Subscription limit" in response.json()["detail"]
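
Reviewer note: the body of `validate_push_endpoint` is not part of this diff, but the parametrized test above implies it rejects non-HTTPS schemes and hosts outside a trusted set. The sketch below shows one way such a check could look, assuming a plain host allowlist. `ALLOWED_PUSH_HOSTS`, `check_push_endpoint`, and `is_trusted` are hypothetical names, and the allowlist contains only the one host that appears in the tests.

```python
from urllib.parse import urlsplit

# Hypothetical allowlist; only fcm.googleapis.com appears in the tests above.
ALLOWED_PUSH_HOSTS = frozenset({"fcm.googleapis.com"})


def check_push_endpoint(endpoint: str) -> None:
    """Raise ValueError unless the endpoint is https on an allowlisted host."""
    parts = urlsplit(endpoint)
    if parts.scheme != "https":
        raise ValueError("Push endpoint must use https")
    if parts.hostname not in ALLOWED_PUSH_HOSTS:
        raise ValueError("Push endpoint host is not allowlisted")


def is_trusted(endpoint: str) -> bool:
    """Boolean wrapper around check_push_endpoint for quick checks."""
    try:
        check_push_endpoint(endpoint)
    except ValueError:
        return False
    return True


trusted = is_trusted("https://fcm.googleapis.com/fcm/send/abc123")
```

Raising `ValueError` fits the removed route, which converts `ValueError` into a 400 response.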


@@ -490,9 +490,6 @@ async def get_store_creators(
# Build where clause with sanitized inputs
where = {}
# Only return creators with approved agents
where["num_agents"] = {"gt": 0}
if featured:
where["is_featured"] = featured


@@ -1,5 +1,4 @@
from datetime import datetime
from unittest.mock import AsyncMock
import prisma.enums
import prisma.errors
@@ -51,8 +50,8 @@ async def test_get_store_agents(mocker):
# Mock prisma calls
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_many = AsyncMock(return_value=mock_agents)
mock_store_agent.return_value.count = AsyncMock(return_value=1)
mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)
# Call function
result = await db.get_store_agents()
@@ -95,7 +94,7 @@ async def test_get_store_agent_details(mocker):
# Mock StoreAgent prisma call
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = AsyncMock(return_value=mock_agent)
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
@@ -134,7 +133,7 @@ async def test_get_store_creator(mocker):
# Mock prisma call
mock_creator = mocker.patch("prisma.models.Creator.prisma")
mock_creator.return_value.find_unique = AsyncMock()
mock_creator.return_value.find_unique = mocker.AsyncMock()
# Configure the mock to return values that will pass validation
mock_creator.return_value.find_unique.return_value = mock_creator_data
@@ -190,7 +189,7 @@ async def test_create_store_submission(mocker):
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
subscriptionTier=prisma.enums.SubscriptionTier.BASIC, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",
@@ -237,23 +236,23 @@ async def test_create_store_submission(mocker):
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = AsyncMock(return_value=mock_agent)
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Mock transaction context manager
mock_tx = mocker.MagicMock()
mocker.patch(
"backend.api.features.store.db.transaction",
return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_tx),
__aexit__=AsyncMock(return_value=False),
return_value=mocker.AsyncMock(
__aenter__=mocker.AsyncMock(return_value=mock_tx),
__aexit__=mocker.AsyncMock(return_value=False),
),
)
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
mock_sl.return_value.find_unique = AsyncMock(return_value=None)
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
mock_slv.return_value.create = AsyncMock(return_value=mock_version)
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
# Call function
result = await db.create_store_submission(
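Reviewer note: the transaction mock in this hunk relies on `AsyncMock` accepting `__aenter__` and `__aexit__` as keyword arguments. Below is a minimal standalone sketch of that pattern, using `unittest.mock` directly rather than the `mocker` fixture. `fake_transaction` and `use_transaction` are hypothetical names for illustration, not part of the codebase.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

# Object the test expects to receive inside the `async with` body.
mock_tx = MagicMock()

# Hypothetical stand-in for the patched `transaction` factory: calling it
# returns an object whose __aenter__/__aexit__ are awaitable mocks.
fake_transaction = MagicMock(
    return_value=AsyncMock(
        __aenter__=AsyncMock(return_value=mock_tx),
        # Returning False means exceptions raised in the body are not suppressed.
        __aexit__=AsyncMock(return_value=False),
    )
)


async def use_transaction():
    # Mirrors the shape `async with transaction() as tx:` in code under test.
    async with fake_transaction() as tx:
        return tx


entered = asyncio.run(use_transaction())
```

Because `__aexit__` resolves to `False`, an exception raised inside the `async with` body would still propagate to the caller, which is usually what a test wants.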
@@ -293,8 +292,10 @@ async def test_update_profile(mocker):
# Mock prisma calls
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
mock_profile_db.return_value.update = AsyncMock(return_value=mock_profile)
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)
# Test data
profile = Profile(
@@ -335,7 +336,9 @@ async def test_get_user_profile(mocker):
# Mock prisma calls
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Call function
result = await db.get_user_profile("user-id")
@@ -393,38 +396,3 @@ async def test_get_store_agents_search_category_array_injection():
# Verify the query executed without error
# Category should be parameterized, preventing SQL injection
assert isinstance(result.agents, list)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creators_only_returns_approved(mocker):
mock_creators = [
prisma.models.Creator(
name="Creator One",
username="creator1",
description="desc",
links=["link1"],
avatar_url="avatar.jpg",
num_agents=1,
agent_rating=4.5,
agent_runs=10,
top_categories=["test"],
is_featured=False,
)
]
mock_creator = mocker.patch("prisma.models.Creator.prisma")
mock_creator.return_value.find_many = AsyncMock(return_value=mock_creators)
mock_creator.return_value.count = AsyncMock(return_value=1)
result = await db.get_store_creators()
assert len(result.creators) == 1
assert result.creators[0].username == "creator1"
mock_creator.return_value.find_many.assert_called_once()
mock_creator.return_value.count.assert_called_once()
_, find_kwargs = mock_creator.return_value.find_many.call_args
_, count_kwargs = mock_creator.return_value.count.call_args
assert find_kwargs["where"]["num_agents"] == {"gt": 0}
assert count_kwargs["where"]["num_agents"] == {"gt": 0}

@@ -26,11 +26,10 @@ from fastapi import (
)
from fastapi.concurrency import run_in_threadpool
from prisma.enums import SubscriptionTier
from pydantic import BaseModel, Field
from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
from backend.api.features.workspace.routes import create_file_download_response
from backend.api.model import (
CreateAPIKeyRequest,
CreateAPIKeyResponse,
@@ -44,33 +43,23 @@ from backend.api.model import (
UploadFileResponse,
)
from backend.blocks import get_block, get_blocks
from backend.copilot.rate_limit import get_tier_multipliers
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import (
AutoTopUpConfig,
PendingChangeUnknown,
RefundRequest,
TransactionHistory,
UserCredit,
cancel_stripe_subscription,
create_subscription_checkout,
get_active_subscription_period_end,
get_auto_top_up,
get_pending_subscription_change,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
handle_subscription_payment_success,
modify_stripe_subscription_for_tier,
release_pending_subscription_schedule,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
sync_subscription_schedule_from_stripe,
)
from backend.data.graph import GraphSettings
from backend.data.model import CredentialsMetaInput, UserOnboarding
@@ -100,7 +89,6 @@ from backend.data.user import (
update_user_notification_preference,
update_user_timezone,
)
from backend.data.workspace import get_workspace_file_by_id
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.integrations.webhooks.graph_lifecycle_hooks import (
@@ -702,51 +690,19 @@ async def get_user_auto_top_up(
class SubscriptionTierRequest(BaseModel):
tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]
tier: Literal["FREE", "PRO", "BUSINESS"]
success_url: str = ""
cancel_url: str = ""
class SubscriptionCheckoutResponse(BaseModel):
url: str
class SubscriptionStatusResponse(BaseModel):
tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
tier: str
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
tier_multipliers: dict[str, float] = Field(
default_factory=dict,
description=(
"Tier → rate-limit multiplier. Covers the same tiers listed in"
" ``tier_costs`` so the frontend can render rate-limit badges"
" relative to the lowest visible tier without knowing backend"
" defaults."
),
)
proration_credit_cents: int # unused portion of current sub to convert on upgrade
has_active_stripe_subscription: bool = Field(
default=False,
description=(
"True when the user has an active/trialing Stripe subscription. The"
" frontend uses this to branch upgrade UX: modify-in-place + saved-card"
" auto-charge when True, redirect to Stripe Checkout when False."
),
)
current_period_end: Optional[int] = Field(
default=None,
description=(
"Unix timestamp of the active subscription's current_period_end. Used"
" to show the date Stripe will issue the next invoice (with prorated"
" upgrade charges, if any). None when no active sub."
),
)
pending_tier: Optional[Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]] = None
pending_tier_effective_at: Optional[datetime] = None
url: str = Field(
default="",
description=(
"Populated only when POST /credits/subscription starts a Stripe Checkout"
" Session (BASIC → paid upgrade). Empty string in all other branches —"
" the client redirects to this URL when non-empty."
),
)
def _validate_checkout_redirect_url(url: str) -> bool:
@@ -756,13 +712,15 @@ def _validate_checkout_redirect_url(url: str) -> bool:
success_url/cancel_url that Stripe will redirect users to after checkout.
Pre-parse rejection rules (applied before urlparse):
- URLs containing ``@`` can exploit ``user:pass@host`` authority tricks.
- Backslashes (``\\``) are normalised differently across parsers/browsers.
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
some URL-parsing implementations.
"""
# Reject characters that can confuse URL parsers before any parsing.
if "\\" in url:
return False
for bad_char in ("@", "\\"):
if bad_char in url:
return False
if any(ord(c) < 0x20 for c in url):
return False
@@ -777,11 +735,6 @@ def _validate_checkout_redirect_url(url: str) -> bool:
return False
if parsed.scheme not in ("http", "https"):
return False
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
# can trick browsers into connecting to a different host than displayed.
# ``@`` in query/fragment is harmless and must be allowed.
if "@" in parsed.netloc:
return False
return (
parsed.scheme == allowed_parsed.scheme
and parsed.netloc == allowed_parsed.netloc
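Reviewer note: the two sides of this hunk differ in where `@` is rejected, either anywhere in the raw URL or only in the parsed netloc. The sketch below shows the netloc-only variant so the difference is easy to test in isolation. It is an illustration under stated assumptions, not the project's implementation: `ALLOWED_ORIGIN` and `is_safe_redirect` are hypothetical names, and the real function reads the allowed origin from server settings.

```python
from urllib.parse import urlparse

# Hypothetical allowed origin; the real code derives this from configuration.
ALLOWED_ORIGIN = "https://platform.example.com"


def is_safe_redirect(url: str, allowed: str = ALLOWED_ORIGIN) -> bool:
    """Return True only if `url` has the same scheme and host as `allowed`."""
    # Pre-parse rejection: backslashes and control characters are handled
    # inconsistently across URL parsers and browsers.
    if "\\" in url:
        return False
    if any(ord(c) < 0x20 for c in url):
        return False

    try:
        parsed = urlparse(url)
        allowed_parsed = urlparse(allowed)
    except ValueError:
        return False

    if parsed.scheme not in ("http", "https"):
        return False
    # `user:pass@host` tricks live in the netloc; an `@` in the query string
    # or fragment is harmless and is accepted by this variant.
    if "@" in parsed.netloc:
        return False

    return (
        parsed.scheme == allowed_parsed.scheme
        and parsed.netloc == allowed_parsed.netloc
    )
```

The stricter variant (rejecting `@` anywhere in the string) is simpler but would also refuse a legitimate URL such as one carrying an email address in its query string.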
@@ -822,89 +775,35 @@ async def get_subscription_status(
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionStatusResponse:
user = await get_user_by_id(user_id)
tier = user.subscription_tier or SubscriptionTier.NO_TIER
tier = user.subscription_tier or SubscriptionTier.FREE
# Tiers that *can* have a Stripe price configured (and therefore appear
# in the tier picker if the LD flag exposes a price-id). NO_TIER is not
# priceable — it's the implicit "no active subscription" state.
priceable_tiers = [
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
SubscriptionTier.MAX,
SubscriptionTier.BUSINESS,
]
paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS]
price_ids = await asyncio.gather(
*[get_subscription_price_id(t) for t in priceable_tiers]
*[get_subscription_price_id(t) for t in paid_tiers]
)
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs[t.value] = cost
tier_costs: dict[str, int] = {}
for t, pid, cost in zip(priceable_tiers, price_ids, costs):
if pid:
tier_costs[t.value] = cost
# Expose the effective rate-limit multipliers alongside prices so the
# frontend can render "Nx rate limits" relative to the lowest visible
# tier without hard-coding backend defaults. Only emit entries for tiers
# that land in ``tier_costs`` — rows hidden at the price layer must stay
# hidden in the multiplier layer too.
multipliers = await get_tier_multipliers()
tier_multipliers: dict[str, float] = {
t.value: multipliers.get(t, 1.0)
for t in priceable_tiers
if t.value in tier_costs
}
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit, current_period_end = await asyncio.gather(
get_proration_credit_cents(user_id, current_monthly_cost),
get_active_subscription_period_end(user_id),
)
try:
pending = await get_pending_subscription_change(user_id)
except (stripe.StripeError, PendingChangeUnknown):
# Swallow Stripe-side failures (rate limits, transient network) AND
# PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
# propagate past the cache so the next request retries fresh instead
# of serving a stale None for the TTL window. Let real bugs (KeyError,
# AttributeError, etc.) propagate so they surface in Sentry.
logger.exception(
"get_subscription_status: failed to resolve pending change for user %s",
user_id,
)
pending = None
response = SubscriptionStatusResponse(
return SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=current_monthly_cost,
monthly_cost=tier_costs.get(tier.value, 0),
tier_costs=tier_costs,
tier_multipliers=tier_multipliers,
proration_credit_cents=proration_credit,
has_active_stripe_subscription=current_period_end is not None,
current_period_end=current_period_end,
)
if pending is not None:
pending_tier_enum, pending_effective_at = pending
if pending_tier_enum in (
SubscriptionTier.NO_TIER,
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
SubscriptionTier.MAX,
SubscriptionTier.BUSINESS,
):
response.pending_tier = pending_tier_enum.value
response.pending_tier_effective_at = pending_effective_at
return response
@v1_router.post(
path="/credits/subscription",
summary="Update subscription tier or start a Stripe Checkout session",
summary="Start a Stripe Checkout session to upgrade subscription tier",
operation_id="updateSubscriptionTier",
tags=["credits"],
dependencies=[Security(requires_user)],
@@ -912,59 +811,31 @@ async def get_subscription_status(
async def update_subscription_tier(
request: SubscriptionTierRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionStatusResponse:
# Pydantic validates tier is one of BASIC/PRO/MAX/BUSINESS via Literal type.
) -> SubscriptionCheckoutResponse:
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
tier = SubscriptionTier(request.tier)
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
user = await get_user_by_id(user_id)
if (
user.subscription_tier or SubscriptionTier.NO_TIER
) == SubscriptionTier.ENTERPRISE:
if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
raise HTTPException(
status_code=403,
detail="ENTERPRISE subscription changes must be managed by an administrator",
)
# Same-tier request = "stay on my current tier" = cancel any pending
# scheduled change (paid→paid downgrade or paid→BASIC cancel). This is the
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
# route. Safe when no pending change exists: release_pending_subscription_schedule
# returns False and we simply return the current status.
if (user.subscription_tier or SubscriptionTier.NO_TIER) == tier:
try:
await release_pending_subscription_schedule(user_id)
except stripe.StripeError as e:
logger.exception(
"Stripe error releasing pending subscription change for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel the pending subscription change right now. "
"Please try again or contact support."
),
)
return await get_subscription_status(user_id)
payment_enabled = await is_feature_enabled(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
target_price_id = await get_subscription_price_id(tier)
# Cancel: target NO_TIER. Schedule Stripe cancellation at period end;
# cancel_at_period_end=True lets the webhook flip the DB tier. No active
# sub (admin-granted or never-paid) or payment disabled → DB flip.
# NO_TIER is never priceable, so this branch always fires for cancel
# requests regardless of LD config.
if tier == SubscriptionTier.NO_TIER:
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
if tier == SubscriptionTier.FREE:
if payment_enabled:
try:
had_subscription = await cancel_stripe_subscription(user_id)
await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
# Log full Stripe error server-side but return a generic message
# to the client — raw Stripe errors can leak customer/sub IDs and
# infrastructure config details.
logger.exception(
"Stripe error cancelling subscription for user %s: %s",
user_id,
@@ -977,45 +848,24 @@ async def update_subscription_tier(
"Please try again or contact support."
),
)
if not had_subscription:
await set_subscription_tier(user_id, tier)
return await get_subscription_status(user_id)
await set_subscription_tier(user_id, tier)
return await get_subscription_status(user_id)
return SubscriptionCheckoutResponse(url="")
# Beta users (payment not enabled) → update tier directly without Stripe.
if not payment_enabled:
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier.value}",
)
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Target has no LD price — not provisionable (matches the GET hiding).
if target_price_id is None:
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier.value}",
)
# No-op short-circuit: if the user is already on the requested paid tier,
# do NOT create a new Checkout Session. Without this guard, a duplicate
# request (double-click, retried POST, stale page) creates a second
# subscription for the same price; the user would be charged for both
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
# which only fires after the second charge has cleared.
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
return SubscriptionCheckoutResponse(url="")
# Modify in place if there's a sub; else fall through to Checkout below.
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return await get_subscription_status(user_id)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
# No active Stripe subscription → create Stripe Checkout Session.
# Paid upgrade → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
@@ -1024,24 +874,6 @@ async def update_subscription_tier(
# Open-redirect protection: both URLs must point to the configured frontend
# origin, otherwise an attacker could use our Stripe integration as a
# redirector to arbitrary phishing sites.
#
# Fail early with a clear 503 if the server is misconfigured (neither
# frontend_base_url nor platform_base_url set), so operators get an
# actionable error instead of the misleading "must match the platform
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
# produce when `allowed` is empty.
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
logger.error(
"update_subscription_tier: neither frontend_base_url nor "
"platform_base_url is configured; cannot validate checkout redirect URLs"
)
raise HTTPException(
status_code=503,
detail=(
"Payment redirect URLs cannot be validated: "
"frontend_base_url or platform_base_url must be set on the server."
),
)
if not _validate_checkout_redirect_url(
request.success_url
) or not _validate_checkout_redirect_url(request.cancel_url):
@@ -1070,9 +902,7 @@ async def update_subscription_tier(
),
)
status = await get_subscription_status(user_id)
status.url = url
return status
return SubscriptionCheckoutResponse(url=url)
@v1_router.post(
@@ -1137,24 +967,6 @@ async def stripe_webhook(request: Request):
):
await sync_subscription_from_stripe(data_object)
# `subscription_schedule.updated` is deliberately omitted: our own
# `SubscriptionSchedule.create` + `.modify` calls in
# `_schedule_downgrade_at_period_end` would fire that event right back at us
# and loop redundant traffic through this handler. We only care about state
# transitions (released / completed); phase advance to the new price is
# already covered by `customer.subscription.updated`.
if event_type in (
"subscription_schedule.released",
"subscription_schedule.completed",
):
await sync_subscription_schedule_from_stripe(data_object)
if event_type == "invoice.payment_succeeded":
await handle_subscription_payment_success(data_object)
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
@@ -1749,10 +1561,6 @@ async def enable_execution_sharing(
# Generate a unique share token
share_token = str(uuid.uuid4())
# Remove stale allowlist records before updating the token — prevents a
# window where old records + new token could coexist.
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
# Update the execution with share info
await execution_db.update_graph_execution_share_status(
execution_id=graph_exec_id,
@@ -1762,14 +1570,6 @@ async def enable_execution_sharing(
shared_at=datetime.now(timezone.utc),
)
# Create allowlist of workspace files referenced in outputs
await execution_db.create_shared_execution_files(
execution_id=graph_exec_id,
share_token=share_token,
user_id=user_id,
outputs=execution.outputs,
)
# Return the share URL
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
share_url = f"{frontend_url}/share/{share_token}"
@@ -1795,9 +1595,6 @@ async def disable_execution_sharing(
if not execution:
raise HTTPException(status_code=404, detail="Execution not found")
# Remove shared file allowlist records
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
# Remove share info
await execution_db.update_graph_execution_share_status(
execution_id=graph_exec_id,
@@ -1823,43 +1620,6 @@ async def get_shared_execution(
return execution
@v1_router.get(
"/public/shared/{share_token}/files/{file_id}/download",
summary="Download a file from a shared execution",
operation_id="download_shared_file",
tags=["graphs"],
)
async def download_shared_file(
share_token: Annotated[
str,
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
],
file_id: Annotated[
str,
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
],
) -> Response:
"""Download a workspace file from a shared execution (no auth required).
Validates that the file was explicitly exposed when sharing was enabled.
Returns a uniform 404 for all failure modes to prevent enumeration attacks.
"""
# Single-query validation against the allowlist
execution_id = await execution_db.get_shared_execution_file(
share_token=share_token, file_id=file_id
)
if not execution_id:
raise HTTPException(status_code=404, detail="Not found")
# Look up the actual file (no workspace scoping needed — the allowlist
# already validated that this file belongs to the shared execution)
file = await get_workspace_file_by_id(file_id)
if not file:
raise HTTPException(status_code=404, detail="Not found")
return await create_file_download_response(file, inline=True)
########################################################
##################### Schedules ########################
########################################################

@@ -1,157 +0,0 @@
"""Tests for the public shared file download endpoint."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import Response
from backend.api.features.v1 import v1_router
from backend.data.workspace import WorkspaceFile
app = FastAPI()
app.include_router(v1_router, prefix="/api")
VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
def _make_workspace_file(**overrides) -> WorkspaceFile:
defaults = {
"id": VALID_FILE_ID,
"workspace_id": "ws-001",
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"name": "image.png",
"path": "/image.png",
"storage_path": "local://uploads/image.png",
"mime_type": "image/png",
"size_bytes": 4,
"checksum": None,
"is_deleted": False,
"deleted_at": None,
"metadata": {},
}
defaults.update(overrides)
return WorkspaceFile(**defaults)
def _mock_download_response(**kwargs):
"""Return an AsyncMock that resolves to a Response with inline disposition."""
async def _handler(file, *, inline=False):
return Response(
content=b"\x89PNG",
media_type="image/png",
headers={
"Content-Disposition": (
'inline; filename="image.png"'
if inline
else 'attachment; filename="image.png"'
),
"Content-Length": "4",
},
)
return _handler
class TestDownloadSharedFile:
"""Tests for GET /api/public/shared/{token}/files/{id}/download."""
@pytest.fixture(autouse=True)
def _client(self):
self.client = TestClient(app, raise_server_exceptions=False)
def test_valid_token_and_file_returns_inline_content(self):
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=_make_workspace_file(),
),
patch(
"backend.api.features.v1.create_file_download_response",
side_effect=_mock_download_response(),
),
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 200
assert response.content == b"\x89PNG"
assert "inline" in response.headers["Content-Disposition"]
def test_invalid_token_format_returns_422(self):
response = self.client.get(
f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 422
def test_token_not_in_allowlist_returns_404(self):
with patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value=None,
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 404
def test_file_missing_from_workspace_returns_404(self):
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=None,
),
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 404
def test_uniform_404_prevents_enumeration(self):
"""Both failure modes produce identical 404 — no information leak."""
with patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value=None,
):
resp_no_allow = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=None,
),
):
resp_no_file = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert resp_no_allow.status_code == 404
assert resp_no_file.status_code == 404
assert resp_no_allow.json() == resp_no_file.json()

@@ -29,9 +29,7 @@ from backend.util.workspace import WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(
filename: str, disposition: str = "attachment"
) -> str:
def _sanitize_filename_for_header(filename: str) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
@@ -46,11 +44,11 @@ def _sanitize_filename_for_header(
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'{disposition}; filename="{sanitized}"'
return f'attachment; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"{disposition}; filename*=UTF-8''{encoded}"
return f"attachment; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
@@ -60,26 +58,19 @@ router = fastapi.APIRouter(
)
def _create_streaming_response(
content: bytes, file: WorkspaceFile, *, inline: bool = False
) -> Response:
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
"""Create a streaming response for file content."""
disposition = _sanitize_filename_for_header(
file.name, disposition="inline" if inline else "attachment"
)
return Response(
content=content,
media_type=file.mime_type,
headers={
"Content-Disposition": disposition,
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
},
)
async def create_file_download_response(
file: WorkspaceFile, *, inline: bool = False
) -> Response:
async def _create_file_download_response(file: WorkspaceFile) -> Response:
"""
Create a download response for a workspace file.
@@ -91,7 +82,7 @@ async def create_file_download_response(
# For local storage, stream the file directly
if file.storage_path.startswith("local://"):
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file, inline=inline)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
@@ -99,7 +90,7 @@ async def create_file_download_response(
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file, inline=inline)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
@@ -111,7 +102,7 @@ async def create_file_download_response(
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file, inline=inline)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
@@ -178,7 +169,7 @@ async def download_file(
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await create_file_download_response(file)
return await _create_file_download_response(file)
@router.delete(

@@ -600,221 +600,3 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)
# -- _sanitize_filename_for_header tests --
class TestSanitizeFilenameForHeader:
def test_simple_ascii_attachment(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
assert _sanitize_filename_for_header("report.pdf") == (
'attachment; filename="report.pdf"'
)
def test_inline_disposition(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
assert _sanitize_filename_for_header("image.png", disposition="inline") == (
'inline; filename="image.png"'
)
def test_strips_cr_lf_null(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
assert "\r" not in result
assert "\n" not in result
assert "\x00" not in result
assert 'filename="abcd.txt"' in result
def test_escapes_quotes(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header('file"name.txt')
assert 'filename="file\\"name.txt"' in result
def test_header_injection_blocked(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
# CR/LF stripped — the remaining text is safely inside the quoted value
assert "\r" not in result
assert "\n" not in result
assert result == 'attachment; filename="evil.txtX-Injected: true"'
def test_unicode_uses_rfc5987(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("日本語.pdf")
assert "filename*=UTF-8''" in result
assert "attachment" in result
def test_unicode_inline(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("图片.png", disposition="inline")
assert result.startswith("inline; filename*=UTF-8''")
def test_empty_filename(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("")
assert result == 'attachment; filename=""'
# -- _create_streaming_response tests --
class TestCreateStreamingResponse:
def test_attachment_disposition_by_default(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name="data.bin", mime_type="application/octet-stream")
response = _create_streaming_response(b"binary-data", file)
assert (
response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
)
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["Content-Length"] == "11"
assert response.body == b"binary-data"
def test_inline_disposition(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name="photo.png", mime_type="image/png")
response = _create_streaming_response(b"\x89PNG", file, inline=True)
assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
assert response.headers["Content-Type"] == "image/png"
def test_inline_sanitizes_filename(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
response = _create_streaming_response(b"data", file, inline=True)
assert "\r" not in response.headers["Content-Disposition"]
assert "\n" not in response.headers["Content-Disposition"]
assert "inline" in response.headers["Content-Disposition"]
def test_content_length_matches_body(self):
from backend.api.features.workspace.routes import _create_streaming_response
content = b"x" * 1000
file = _make_file(name="big.bin", mime_type="application/octet-stream")
response = _create_streaming_response(content, file)
assert response.headers["Content-Length"] == "1000"
# -- create_file_download_response tests --
class TestCreateFileDownloadResponse:
@pytest.mark.asyncio
async def test_local_storage_returns_streaming_response(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = b"file contents"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(
storage_path="local://uploads/test.txt",
mime_type="text/plain",
)
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"file contents"
assert "attachment" in response.headers["Content-Disposition"]
@pytest.mark.asyncio
async def test_local_storage_inline(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = b"\x89PNG"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(
storage_path="local://uploads/photo.png",
mime_type="image/png",
name="photo.png",
)
response = await create_file_download_response(file, inline=True)
assert "inline" in response.headers["Content-Disposition"]
@pytest.mark.asyncio
async def test_gcs_redirect(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.return_value = (
"https://storage.googleapis.com/signed-url"
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.pdf")
response = await create_file_download_response(file)
assert response.status_code == 302
assert (
response.headers["location"] == "https://storage.googleapis.com/signed-url"
)
@pytest.mark.asyncio
async def test_gcs_api_fallback_streams_directly(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.return_value = "/api/fallback"
mock_storage.retrieve.return_value = b"fallback content"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"fallback content"
@pytest.mark.asyncio
async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
mock_storage.retrieve.return_value = b"streamed"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"streamed"
@pytest.mark.asyncio
async def test_gcs_total_failure_raises(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
mock_storage.retrieve.side_effect = RuntimeError("Also failed")
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
with pytest.raises(RuntimeError, match="Also failed"):
await create_file_download_response(file)
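
The `_sanitize_filename_for_header` tests above pin down the expected behavior without showing the helper itself. Below is a minimal sketch that would satisfy those assertions, assuming control characters are dropped, quotes are backslash-escaped, and non-ASCII names fall back to the RFC 5987 `filename*` form. The function name and body are illustrative and are not the code in `routes.py`.

```python
import re
from urllib.parse import quote


def sanitize_filename_for_header(filename: str, disposition: str = "attachment") -> str:
    """Hypothetical helper: build a Content-Disposition value for `filename`."""
    # Drop NUL, CR, LF and other control characters so the value cannot
    # terminate the header line and inject a new one.
    cleaned = re.sub(r"[\x00-\x1f\x7f]", "", filename)

    try:
        cleaned.encode("ascii")
    except UnicodeEncodeError:
        # Non-ASCII names: percent-encode as UTF-8 (RFC 5987 extended form).
        return f"{disposition}; filename*=UTF-8''{quote(cleaned, safe='')}"

    # ASCII names: escape backslashes and double quotes inside the quoted string.
    escaped = cleaned.replace("\\", "\\\\").replace('"', '\\"')
    return f'{disposition}; filename="{escaped}"'
```

Stripping CR/LF before quoting is what keeps the `X-Injected: true` text inside the quoted value in `test_header_injection_blocked` instead of letting it start a new header.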


@@ -17,7 +17,6 @@ from fastapi.routing import APIRoute
from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.diagnostics_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
@@ -32,9 +31,7 @@ import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.platform_linking.routes
import backend.api.features.postmark.postmark
import backend.api.features.push.routes as push_routes
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
@@ -42,7 +39,6 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.redis_client
import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
@@ -97,8 +93,6 @@ async def lifespan_context(app: fastapi.FastAPI):
verify_auth_settings()
await backend.data.db.connect()
# Eager connect to fail-fast if Redis is unreachable.
await backend.data.redis_client.get_redis_async()
# Configure thread pool for FastAPI sync operation performance
# CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
@@ -150,18 +144,7 @@ async def lifespan_context(app: fastapi.FastAPI):
except Exception as e:
logger.warning(f"Error shutting down workspace storage: {e}")
# Each cleanup is wrapped so one failure doesn't block the rest. The
# Redis close in particular silences asyncio's "Unclosed ClusterNode"
# GC warning at interpreter shutdown.
try:
await backend.data.redis_client.disconnect_async()
except Exception:
logger.warning("redis_client.disconnect_async failed", exc_info=True)
try:
await backend.data.db.disconnect()
except Exception:
logger.warning("db.disconnect failed", exc_info=True)
await backend.data.db.disconnect()
def custom_generate_unique_id(route: APIRoute):
@@ -337,11 +320,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/credits",
)
app.include_router(
backend.api.features.admin.diagnostics_admin_routes.router,
tags=["v2", "admin"],
prefix="/api",
)
app.include_router(
backend.api.features.admin.execution_analytics_routes.router,
tags=["v2", "admin"],
@@ -394,16 +372,6 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
push_routes.router,
tags=["push"],
prefix="/api/push",
)
app.include_router(
backend.api.features.platform_linking.routes.router,
tags=["platform-linking"],
prefix="/api/platform-linking",
)
app.mount("/external-api", external_api)
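
The shutdown comment in the `lifespan_context` hunk above says each cleanup is wrapped so that one failure does not block the rest. A minimal sketch of that pattern follows, assuming the cleanups are plain coroutine functions. `run_cleanups` and the step names are hypothetical and only illustrate the control flow.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def run_cleanups(*steps):
    """Await each (name, coroutine function) pair in order.

    A failing step is logged and recorded, and the remaining steps still run.
    Returns the names of the steps that raised.
    """
    failed = []
    for name, step in steps:
        try:
            await step()
        except Exception:
            logger.warning("%s failed", name, exc_info=True)
            failed.append(name)
    return failed
```

With a single unguarded `await backend.data.db.disconnect()`, by contrast, an exception raised by an earlier cleanup would prevent the disconnect from running at all.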


@@ -1,3 +1,4 @@
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Protocol
@@ -16,12 +17,14 @@ from backend.api.model import (
WSSubscribeGraphExecutionsRequest,
)
from backend.api.utils.cors import build_cors_params
from backend.data import db, redis_client
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.notification_bus import AsyncRedisNotificationEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.monitoring.instrumentation import (
instrument_fastapi,
update_websocket_connections,
)
from backend.util.retry import continuous_retry
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
@@ -31,24 +34,10 @@ settings = Settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Prisma is needed to resolve graph_id from graph_exec_id on subscribe.
await db.connect()
# Eager connect to fail-fast if Redis is unreachable.
await redis_client.get_redis_async()
try:
yield
finally:
# Each cleanup is wrapped so one failure doesn't block the rest. The
# Redis close silences asyncio's "Unclosed ClusterNode" GC warning at
# interpreter shutdown.
try:
await redis_client.disconnect_async()
except Exception:
logger.warning("redis_client.disconnect_async failed", exc_info=True)
try:
await db.disconnect()
except Exception:
logger.warning("db.disconnect failed", exc_info=True)
manager = get_connection_manager()
fut = asyncio.create_task(event_broadcaster(manager))
fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
yield
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
@@ -72,6 +61,31 @@ def get_connection_manager():
return _connection_manager
@continuous_retry()
async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
try:
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
async def authenticate_websocket(websocket: WebSocket) -> str:
if not settings.config.enable_auth:
return DEFAULT_USER_ID
@@ -283,21 +297,6 @@ async def websocket_router(
).model_dump_json()
)
continue
except ValueError as e:
logger.warning(
"Subscription rejected for user #%s on '%s': %s",
user_id,
message.method.value,
e,
)
await websocket.send_text(
WSMessage(
method=WSMethod.ERROR,
success=False,
error=str(e),
).model_dump_json()
)
continue
except Exception as e:
logger.error(
f"Error while handling '{message.method.value}' message "
@@ -322,13 +321,9 @@ async def websocket_router(
)
except WebSocketDisconnect:
manager.disconnect_socket(websocket, user_id=user_id)
logger.debug("WebSocket client disconnected")
except Exception:
logger.exception(f"Unexpected error in websocket_router for user #{user_id}")
finally:
# Always release subscription pumps + Redis connections, regardless of how
# the loop exited — otherwise non-WebSocketDisconnect failures leak both.
await manager.disconnect_socket(websocket, user_id=user_id)
update_websocket_connections(user_id, -1)
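
The `event_broadcaster` function in this file runs two listeners with `asyncio.gather` and closes both buses in a `finally` block. The sketch below reproduces that shape with a stand-in bus, assuming `listen()` is an async generator and `close()` is a coroutine. `FakeBus` and `broadcast` are hypothetical names and not part of the backend.

```python
import asyncio


class FakeBus:
    """Stand-in for a Redis pub/sub event bus (illustrative only)."""

    def __init__(self, events):
        self._events = list(events)
        self.closed = False

    async def listen(self):
        for event in self._events:
            await asyncio.sleep(0)  # yield control, as a network read would
            yield event

    async def close(self):
        self.closed = True


async def broadcast(bus_a: FakeBus, bus_b: FakeBus, sink: list) -> None:
    try:
        async def worker(bus: FakeBus) -> None:
            async for event in bus.listen():
                sink.append(event)

        # Both workers run concurrently; gather raises as soon as either fails.
        await asyncio.gather(worker(bus_a), worker(bus_b))
    finally:
        # Runs on normal completion, on error, and on cancellation.
        await bus_a.close()
        await bus_b.close()
```

Putting the `close()` calls in `finally` matters when the function is retried: without it, every retry could leave a pub/sub connection open.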


@@ -44,12 +44,9 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
"backend.api.ws_api.build_cors_params", return_value=cors_params
)
with (
override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
),
override_config(settings, "app_env", AppEnvironment.LOCAL),
):
with override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
), override_config(settings, "app_env", AppEnvironment.LOCAL):
WebsocketServer().run()
build_cors.assert_called_once_with(
@@ -68,12 +65,9 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
mocker.patch("backend.api.ws_api.uvicorn.run")
with (
override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
),
override_config(settings, "app_env", AppEnvironment.PRODUCTION),
):
with override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
with pytest.raises(ValueError):
WebsocketServer().run()
@@ -296,232 +290,7 @@ async def test_handle_unsubscribe_missing_data(
message=message,
)
mock_manager.unsubscribe_graph_exec.assert_not_called()
mock_manager._unsubscribe.assert_not_called()
mock_websocket.send_text.assert_called_once()
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
# ---------- Per-graph subscribe branch ----------
@pytest.mark.asyncio
async def test_handle_subscribe_graph_execs_branch(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
"""The SUBSCRIBE_GRAPH_EXECS branch must route to subscribe_graph_execs,
not subscribe_graph_exec — regression guard for the aggregate channel."""
message = WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXECS,
data={"graph_id": "graph-abc"},
)
mock_manager.subscribe_graph_execs.return_value = (
"user-1|graph#graph-abc|executions"
)
await handle_subscribe(
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.subscribe_graph_execs.assert_called_once_with(
user_id="user-1",
graph_id="graph-abc",
websocket=mock_websocket,
)
mock_manager.subscribe_graph_exec.assert_not_called()
mock_websocket.send_text.assert_called_once()
assert (
'"method":"subscribe_graph_executions"'
in mock_websocket.send_text.call_args[0][0]
)
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@pytest.mark.asyncio
async def test_handle_subscribe_rejects_unrelated_method(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
"""handle_subscribe must raise for methods that aren't SUBSCRIBE_*."""
import pytest as _pytest
message = WSMessage(
method=WSMethod.HEARTBEAT,
data={"graph_exec_id": "x"},
)
with _pytest.raises(ValueError):
await handle_subscribe(
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
# ---------- authenticate_websocket branches ----------
@pytest.mark.asyncio
async def test_authenticate_websocket_missing_token_closes_4001(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
ws = AsyncMock(spec=WebSocket)
ws.query_params = {}
user_id = await authenticate_websocket(ws)
ws.close.assert_awaited_once()
assert ws.close.call_args.kwargs["code"] == 4001
assert user_id == ""
@pytest.mark.asyncio
async def test_authenticate_websocket_invalid_token_closes_4003(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
mocker.patch(
"backend.api.ws_api.parse_jwt_token", side_effect=ValueError("bad token")
)
ws = AsyncMock(spec=WebSocket)
ws.query_params = {"token": "abc"}
user_id = await authenticate_websocket(ws)
ws.close.assert_awaited_once()
assert ws.close.call_args.kwargs["code"] == 4003
assert user_id == ""
@pytest.mark.asyncio
async def test_authenticate_websocket_missing_sub_closes_4002(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"not_sub": "x"})
ws = AsyncMock(spec=WebSocket)
ws.query_params = {"token": "abc"}
user_id = await authenticate_websocket(ws)
ws.close.assert_awaited_once()
assert ws.close.call_args.kwargs["code"] == 4002
assert user_id == ""
@pytest.mark.asyncio
async def test_authenticate_websocket_happy_path_returns_sub(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"sub": "user-X"})
ws = AsyncMock(spec=WebSocket)
ws.query_params = {"token": "abc"}
user_id = await authenticate_websocket(ws)
assert user_id == "user-X"
@pytest.mark.asyncio
async def test_authenticate_websocket_auth_disabled_returns_default(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", False)
ws = AsyncMock(spec=WebSocket)
ws.query_params = {}
user_id = await authenticate_websocket(ws)
assert user_id == DEFAULT_USER_ID
# ---------- get_connection_manager singleton ----------
def test_get_connection_manager_singleton() -> None:
"""Repeated calls must return the same ConnectionManager — the WS router
depends on a single process-wide subscription table."""
import backend.api.ws_api as ws_api
ws_api._connection_manager = None
a = ws_api.get_connection_manager()
b = ws_api.get_connection_manager()
assert a is b
assert isinstance(a, ConnectionManager)
# ---------- Lifespan: Prisma connect/disconnect ----------
@pytest.mark.asyncio
async def test_lifespan_connects_and_disconnects_prisma(mocker) -> None:
"""Lifespan must both connect() and disconnect() db — the subscribe path
resolves graph_id via Prisma so a missing connect() is the regression bug."""
from fastapi import FastAPI
from backend.api.ws_api import lifespan
mock_db = mocker.patch("backend.api.ws_api.db")
mock_db.connect = AsyncMock()
mock_db.disconnect = AsyncMock()
dummy_app = FastAPI()
async with lifespan(dummy_app):
mock_db.connect.assert_awaited_once()
mock_db.disconnect.assert_not_called()
mock_db.disconnect.assert_awaited_once()
@pytest.mark.asyncio
async def test_lifespan_still_disconnects_on_exception(mocker) -> None:
"""If the app raises inside the yield, Prisma must still disconnect."""
from fastapi import FastAPI
from backend.api.ws_api import lifespan
mock_db = mocker.patch("backend.api.ws_api.db")
mock_db.connect = AsyncMock()
mock_db.disconnect = AsyncMock()
dummy_app = FastAPI()
class _Boom(Exception):
pass
with pytest.raises(_Boom):
async with lifespan(dummy_app):
raise _Boom()
mock_db.disconnect.assert_awaited_once()
# ---------- Health endpoint ----------
def test_health_endpoint_returns_ok() -> None:
# TestClient triggers lifespan — stub it out so Prisma isn't hit.
from contextlib import asynccontextmanager
from fastapi.testclient import TestClient
import backend.api.ws_api as ws_api
@asynccontextmanager
async def _noop_lifespan(app):
yield
# Replace the app-level lifespan temporarily.
real_router_lifespan = ws_api.app.router.lifespan_context
ws_api.app.router.lifespan_context = _noop_lifespan
try:
with TestClient(ws_api.app) as client:
r = client.get("/")
assert r.status_code == 200
assert r.json() == {"status": "healthy"}
finally:
ws_api.app.router.lifespan_context = real_router_lifespan


@@ -38,23 +38,19 @@ def main(**kwargs):
from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
from backend.copilot.bot.app import CoPilotChatBridge
from backend.copilot.executor.manager import CoPilotExecutor
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
from backend.notifications import NotificationManager
from backend.platform_linking.manager import PlatformLinkingManager
run_processes(
DatabaseManager().set_log_level("warning"),
Scheduler(),
NotificationManager(),
PlatformLinkingManager(),
WebsocketServer(),
AgentServer(),
ExecutionManager(),
CoPilotExecutor(),
CoPilotChatBridge(),
**kwargs,
)


@@ -25,7 +25,6 @@ from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
is_credentials_field_name,
)
@@ -44,7 +43,7 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails
from backend.data.model import ContributorDetails, NodeExecutionStats
from ..data.graph import Link
@@ -96,64 +95,27 @@ class BlockCategory(Enum):
class BlockCostType(str, Enum):
# RUN : cost_amount credits per run.
# BYTE : cost_amount credits per byte of input data.
# SECOND : cost_amount credits per cost_divisor walltime seconds.
# ITEMS : cost_amount credits per cost_divisor items (from stats).
# COST_USD : cost_amount credits per USD of stats.provider_cost.
# TOKENS : per-(model, provider) rate table lookup; see TOKEN_COST.
RUN = "run"
BYTE = "byte"
SECOND = "second"
ITEMS = "items"
COST_USD = "cost_usd"
TOKENS = "tokens"
@property
def is_dynamic(self) -> bool:
"""Real charge is computed post-flight from stats.
Dynamic types (SECOND/ITEMS/COST_USD/TOKENS) return 0 pre-flight and
settle against stats via charge_reconciled_usage once the block runs.
"""
return self in _DYNAMIC_COST_TYPES
_DYNAMIC_COST_TYPES: frozenset[BlockCostType] = frozenset(
{
BlockCostType.SECOND,
BlockCostType.ITEMS,
BlockCostType.COST_USD,
BlockCostType.TOKENS,
}
)
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
# cost_divisor: interpret cost_amount as "credits per cost_divisor units".
# Only meaningful for SECOND / ITEMS. TOKENS routes through TOKEN_COST
# rate tables (per-model input/output/cache pricing) and ignores
# cost_divisor entirely. Defaults to 1 so existing RUN/BYTE entries stay
# point-wise. Example: cost_amount=1, cost_divisor=10 under SECOND means
# "1 credit per 10 seconds of walltime".
cost_divisor: int = 1
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
cost_divisor: int = 1,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
cost_divisor=max(1, cost_divisor),
**data,
)
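
The `cost_divisor` comment in this hunk reads `cost_amount` as "credits per `cost_divisor` units", with the example of 1 credit per 10 seconds of walltime. A small worked sketch of that arithmetic follows. The hunk does not show how fractional credits are rounded, so rounding up is an assumption here, and `credits_for_seconds` is a hypothetical helper.

```python
import math


def credits_for_seconds(cost_amount: int, cost_divisor: int, walltime_seconds: float) -> int:
    """Credits owed for `walltime_seconds` at `cost_amount` per `cost_divisor` seconds."""
    # Same clamp as BlockCost.__init__: a divisor below 1 is treated as 1.
    divisor = max(1, cost_divisor)
    # Assumption: partial units are rounded up to a whole credit.
    return math.ceil(cost_amount * walltime_seconds / divisor)
```

For example, `credits_for_seconds(1, 10, 25)` charges for 2.5 units, which becomes 3 credits under the round-up assumption.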
@@ -205,31 +167,9 @@ class BlockSchema(BaseModel):
return cls.cached_jsonschema
@classmethod
def validate_data(
cls,
data: BlockInput,
exclude_fields: set[str] | None = None,
) -> str | None:
schema = cls.jsonschema()
if exclude_fields:
# Drop the excluded fields from both the properties and the
# ``required`` list so jsonschema doesn't flag them as missing.
# Used by the dry-run path to skip credentials validation while
# still validating the remaining block inputs.
schema = {
**schema,
"properties": {
k: v
for k, v in schema.get("properties", {}).items()
if k not in exclude_fields
},
"required": [
r for r in schema.get("required", []) if r not in exclude_fields
],
}
data = {k: v for k, v in data.items() if k not in exclude_fields}
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(
schema=schema,
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
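
The `exclude_fields` variant of `validate_data` in this hunk removes the excluded names from `properties`, from `required`, and from the data before validating. The sketch below isolates that pruning step as a pure function so the effect is easy to see. `prune_schema` is a hypothetical name, and the actual jsonschema validation call is left out.

```python
def prune_schema(schema: dict, data: dict, exclude_fields: set[str]) -> tuple[dict, dict]:
    """Return copies of (schema, data) with `exclude_fields` removed."""
    pruned_schema = {
        **schema,
        "properties": {
            k: v
            for k, v in schema.get("properties", {}).items()
            if k not in exclude_fields
        },
        # Dropping the names from `required` is what prevents a
        # "'credentials' is a required property" error.
        "required": [r for r in schema.get("required", []) if r not in exclude_fields],
    }
    pruned_data = {k: v for k, v in data.items() if k not in exclude_fields}
    return pruned_schema, pruned_data
```

The one-argument variant strips only the data at the call site, so a schema that lists `credentials` under `required` would still report the field as missing during a dry run.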
@@ -370,8 +310,6 @@ class BlockSchema(BaseModel):
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
"is_auto_credential": True,
"input_field_name": info["field_name"],
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True
@@ -482,6 +420,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def extra_credit_charges(self, execution_stats: "NodeExecutionStats") -> int:
"""Return extra credits to charge after this block run completes.
Called by the executor after a block finishes with COMPLETED status.
The return value is the number of additional base-cost credits to
charge beyond the single credit already collected by ``_charge_usage``
at the start of execution. Defaults to 0 (no extra charges).
Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
calls within one run and should be billed per call.
"""
return 0
def __init__(
self,
id: str = "",
@@ -517,6 +468,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
@@ -534,7 +487,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -614,7 +567,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
self.execution_stats += stats
return self.execution_stats
@@ -765,16 +718,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
# (e.g. AgentExecutorBlock) get proper input validation.
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
if is_dry_run:
# Credential fields may be absent (LLM-built agents often skip
# wiring them) or nullified earlier in the pipeline. Validate
# the non-credential inputs against a schema with those fields
# excluded — stripping only the data while keeping them in the
# ``required`` list would falsely report ``'credentials' is a
# required property``.
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
if error := self.input_schema.validate_data(
input_data, exclude_fields=cred_field_names
):
non_cred_data = {
k: v for k, v in input_data.items() if k not in cred_field_names
}
if error := self.input_schema.validate_data(non_cred_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
@@ -788,61 +736,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
)
# Ensure auto-credential kwargs are present before we hand off to
# run(). A missing auto-credential means the upstream field (e.g.
# a Google Drive picker) didn't embed a _credentials_id, or the
# executor couldn't resolve it. Without this guard, run() would
# crash with a TypeError (missing required kwarg) or an opaque
# AttributeError deep inside the provider SDK.
#
# Only raise when the field is ALSO not populated in input_data.
# ``_acquire_auto_credentials`` intentionally skips setting the
# kwarg in two legitimate cases — ``_credentials_id`` is ``None``
# (chained from upstream) or the field is missing from
# ``input_data`` at prep time (connected from upstream block).
# In both cases the upstream block is expected to populate the
# field value by execute time; raising here would break the
# documented ``AgentGoogleDriveFileInputBlock`` chaining pattern.
# Dry-run skips because the executor intentionally runs blocks
# without resolved creds for schema validation.
if not is_dry_run:
for (
kwarg_name,
info,
) in self.input_schema.get_auto_credentials_fields().items():
kwargs.setdefault(kwarg_name, None)
if kwargs[kwarg_name] is not None:
continue
# Upstream-chained pattern: the field was populated by a
# prior node (e.g. AgentGoogleDriveFileInputBlock) whose
# output carries a resolved ``_credentials_id``.
# ``_acquire_auto_credentials`` deliberately doesn't set
# the kwarg in that case because the value isn't available
# at prep time; the executor fills it in before we reach
# ``_execute``. Trust it if the ``_credentials_id`` KEY
# is present — its value may be explicitly ``None`` in
# the chained case (see sentry thread
# PRRT_kwDOJKSTjM58sJfA). Checking truthiness here would
# falsely preempt run() for every valid chained graph
# that ships ``_credentials_id=None`` in the picker
# object. Mirror ``_acquire_auto_credentials``'s own
# skip rule, which treats ``cred_id is None`` as a
# chained-skip signal.
field_name = info["field_name"]
field_value = input_data.get(field_name)
if isinstance(field_value, dict) and "_credentials_id" in field_value:
continue
raise BlockExecutionError(
message=(
f"Missing credentials for '{kwarg_name}'. "
"Select a file via the picker (which carries "
"its credentials), or connect credentials for "
"this block."
),
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
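
The auto-credential guard in this hunk treats a field as upstream-chained when the `_credentials_id` key is present, even if its value is `None`. A minimal sketch of the difference between a key-presence check and a truthiness check follows. Both function names are hypothetical.

```python
from typing import Any


def is_chained_by_key(field_value: Any) -> bool:
    """Key-presence check, as described in the guard's comment."""
    return isinstance(field_value, dict) and "_credentials_id" in field_value


def is_chained_by_truthiness(field_value: Any) -> bool:
    """Truthiness check: rejects the chained case where the id is None."""
    return isinstance(field_value, dict) and bool(field_value.get("_credentials_id"))
```

A picker object such as `{"_credentials_id": None, "name": "report.pdf"}` passes the first check and fails the second, which is the false "missing credentials" error the comment warns about.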


@@ -1,57 +0,0 @@
"""Provider descriptions for services that don't yet have their own ``_config.py``.
Every provider in ``_STATIC_PROVIDER_CONFIGS`` below is declared here because
its block code currently lives either in a single shared file (e.g. the 8 LLM
providers in ``blocks/llm.py``) or in a single-file block that has no dedicated
directory (e.g. ``blocks/reddit.py``).
This file gets loaded by the block auto-loader in ``blocks/__init__.py``
(``rglob("*.py")`` picks it up) so the ``ProviderBuilder(...).build()`` calls
run at startup and populate ``AutoRegistry`` before the first API request.
**Migration path:** when a provider graduates into its own directory with a
proper ``_config.py`` (following the SDK pattern, e.g. ``blocks/linear/_config.py``),
delete its entry here. The metadata will still be served by
``GET /integrations/providers`` — it just moves to live next to the provider's
auth and webhook config.
"""
from backend.data.model import CredentialsType
from backend.sdk import ProviderBuilder
_STATIC_PROVIDER_CONFIGS: dict[str, tuple[str, tuple[CredentialsType, ...]]] = {
# LLM providers that share blocks/llm.py
"aiml_api": ("Unified access to 100+ AI models", ("api_key",)),
"anthropic": ("Claude language models", ("api_key",)),
"groq": ("Fast LLM inference", ("api_key",)),
"llama_api": ("Llama model hosting", ("api_key",)),
"ollama": ("Run open-source LLMs locally", ("api_key",)),
"open_router": ("One API for every LLM", ("api_key",)),
"openai": ("GPT models and embeddings", ("api_key",)),
"v0": ("AI-generated UI components", ("api_key",)),
# Single-file providers (one provider per standalone blocks/*.py file)
"d_id": ("AI avatar and video generation", ("api_key",)),
"e2b": ("Sandboxed code execution", ("api_key",)),
"google_maps": ("Places, directions, geocoding", ("api_key",)),
"http": ("Generic HTTP requests", ("api_key", "host_scoped")),
"ideogram": ("Text-to-image generation", ("api_key",)),
"medium": ("Publish stories and posts", ("api_key",)),
"mem0": ("Long-term memory for agents", ("api_key",)),
"openweathermap": ("Weather data and forecasts", ("api_key",)),
"pinecone": ("Managed vector database", ("api_key",)),
"reddit": ("Subreddits, posts, and comments", ("oauth2",)),
"revid": ("AI-generated short-form video", ("api_key",)),
"screenshotone": ("Automated website screenshots", ("api_key",)),
"smtp": ("Send email via SMTP", ("user_password",)),
"stripe_link": ("Stripe Link wallet for agent payments", ("device_code",)),
"unreal_speech": ("Low-cost text-to-speech", ("api_key",)),
"webshare_proxy": ("Rotating proxies for scraping", ("api_key",)),
}
for _name, (_description, _auth_types) in _STATIC_PROVIDER_CONFIGS.items():
(
ProviderBuilder(_name)
.with_description(_description)
.with_supported_auth_types(*_auth_types)
.build()
)


@@ -171,10 +171,7 @@ class AgentExecutorBlock(Block):
)
self.merge_stats(
NodeExecutionStats(
# Sub-graph already debited each of its own nodes; we
# roll up its total so graph_stats.cost reflects the
# full sub-graph spend.
reconciled_cost_delta=(event.stats.cost if event.stats else 0),
extra_cost=event.stats.cost if event.stats else 0,
extra_steps=event.stats.node_exec_count if event.stats else 0,
)
)


@@ -4,17 +4,11 @@ Shared configuration for all AgentMail blocks.
from agentmail import AsyncAgentMail
from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, SecretStr
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
# AgentMail is in beta with no published paid tier yet, but ~37 blocks
# without any BLOCK_COSTS entry means they currently execute wallet-free.
# 1 cr/call is a conservative interim floor so no AgentMail work leaks
# past billing. Revisit once AgentMail publishes usage-based pricing.
agent_mail = (
ProviderBuilder("agent_mail")
.with_description("Managed email accounts for agents")
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
.with_base_cost(1, BlockCostType.RUN)
.build()
)

@@ -10,7 +10,6 @@ from ._webhook import AirtableWebhookManager
# Configure the Airtable provider with API key authentication
airtable = (
ProviderBuilder("airtable")
.with_description("Bases, tables, and records")
.with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
.with_webhook_manager(AirtableWebhookManager)
.with_base_cost(1, BlockCostType.RUN)

@@ -1,15 +0,0 @@
"""Provider registration for Apollo.
Registers the provider description shown in the settings integrations UI.
Apollo doesn't use a full :class:`ProviderBuilder` chain (auth is set up in
``_auth.py``), so this file only declares metadata.
"""
from backend.sdk import ProviderBuilder
apollo = (
ProviderBuilder("apollo")
.with_description("Sales intelligence and prospecting")
.with_supported_auth_types("api_key")
.build()
)

@@ -4,10 +4,8 @@ import asyncio
import contextvars
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any
from pydantic import field_validator
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
from backend.blocks._base import (
@@ -18,14 +16,12 @@ from backend.blocks._base import (
BlockSchemaOutput,
)
from backend.copilot.permissions import (
DISABLED_LEGACY_TOOL_NAMES,
CopilotPermissions,
ToolName,
all_known_tool_names,
validate_block_identifiers,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
@@ -35,37 +31,6 @@ logger = logging.getLogger(__name__)
# Block ID shared between autopilot.py and copilot prompting.py.
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
# Identifiers used when registering an AutoPilotBlock turn with the
# stream registry — distinguishes block-originated turns from sub-session
# or HTTP SSE turns in logs / observability.
_AUTOPILOT_TOOL_CALL_ID = "autopilot_block"
_AUTOPILOT_TOOL_NAME = "autopilot_block"
# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the
# enqueued turn's terminal event. Graph blocks run synchronously from
# the caller's perspective so we wait effectively as long as needed; 6h
# matches the previous abandoned-task cap and is much longer than any
# legitimate AutoPilot turn.
_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60 # 6 hours
class SubAgentRecursionError(BlockExecutionError):
"""Raised when the AutoPilot sub-agent nesting depth limit is exceeded.
Inherits :class:`BlockExecutionError` — this is a known, handled
runtime failure at the block level (caller nested AutoPilotBlocks
beyond the configured limit). Surfaces with the block_name /
block_id the block framework expects, instead of being wrapped in
``BlockUnknownError``.
"""
def __init__(self, message: str) -> None:
super().__init__(
message=message,
block_name="AutoPilotBlock",
block_id=AUTOPILOT_BLOCK_ID,
)
class ToolCallEntry(TypedDict):
"""A single tool invocation record from an autopilot execution."""
@@ -200,13 +165,6 @@ class AutoPilotBlock(Block):
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@field_validator("tools", mode="before")
@classmethod
def strip_disabled_legacy_tools(cls, tools: Any) -> Any:
if not isinstance(tools, list):
return tools
return [tool for tool in tools if tool not in DISABLED_LEGACY_TOOL_NAMES]
class Output(BlockSchemaOutput):
"""Output schema for the AutoPilot block."""
@@ -305,15 +263,11 @@ class AutoPilotBlock(Block):
user_id: str,
permissions: "CopilotPermissions | None" = None,
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
"""Invoke the copilot on the copilot_executor queue and aggregate the
result.
"""Invoke the copilot and collect all stream results.
Delegates to :func:`run_copilot_turn_via_queue` — the shared
primitive used by ``run_sub_session`` too — which creates the
stream_registry meta record, enqueues the job, and waits on the
Redis stream for the terminal event. Any available
copilot_executor worker picks up the job, so this call survives
the graph-executor worker dying mid-turn (RabbitMQ redelivers).
Delegates to :func:`collect_copilot_response` — the shared helper that
consumes ``stream_chat_completion_sdk`` without wrapping it in an
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
Args:
prompt: The user task/instruction.
@@ -326,8 +280,8 @@ class AutoPilotBlock(Block):
Returns:
A tuple of (response_text, tool_calls, history_json, session_id, usage).
"""
from backend.copilot.sdk.session_waiter import (
run_copilot_turn_via_queue, # avoid circular import
from backend.copilot.sdk.collect import (
collect_copilot_response, # avoid circular import
)
tokens = _check_recursion(max_recursion_depth)
@@ -340,35 +294,14 @@ class AutoPilotBlock(Block):
if system_context:
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
outcome, result = await run_copilot_turn_via_queue(
result = await collect_copilot_response(
session_id=session_id,
user_id=user_id,
message=effective_prompt,
# Graph block execution is synchronous from the caller's
# perspective — wait effectively as long as needed. The
# SDK enforces its own idle-based timeout inside the
# stream_registry pipeline.
timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS,
user_id=user_id,
permissions=effective_permissions,
tool_call_id=_AUTOPILOT_TOOL_CALL_ID,
tool_name=_AUTOPILOT_TOOL_NAME,
)
if outcome == "failed":
raise RuntimeError(
"AutoPilot turn failed — see the session's transcript"
)
if outcome == "running":
raise RuntimeError(
"AutoPilot turn did not complete within "
f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session "
f"{session_id}"
)
# Build a lightweight conversation summary from the aggregated data.
# When ``result.queued`` is True the prompt rode on an already-
# in-flight turn (``run_copilot_turn_via_queue`` queued it and
# waited on the existing turn's stream); the aggregated result
# is still valid, so the same rendering path applies.
# Build a lightweight conversation summary from streamed data.
turn_messages: list[dict[str, Any]] = [
{"role": "user", "content": effective_prompt},
]
@@ -377,7 +310,7 @@ class AutoPilotBlock(Block):
{
"role": "assistant",
"content": result.response_text,
"tool_calls": [tc.model_dump() for tc in result.tool_calls],
"tool_calls": result.tool_calls,
}
)
else:
@@ -388,11 +321,11 @@ class AutoPilotBlock(Block):
tool_calls: list[ToolCallEntry] = [
{
"tool_call_id": tc.tool_call_id,
"tool_name": tc.tool_name,
"input": tc.input,
"output": tc.output,
"success": tc.success,
"tool_call_id": tc["tool_call_id"],
"tool_name": tc["tool_name"],
"input": tc["input"],
"output": tc["output"],
"success": tc["success"],
}
for tc in result.tool_calls
]
@@ -477,41 +410,8 @@ class AutoPilotBlock(Block):
yield "session_id", sid
yield "error", "AutoPilot execution was cancelled."
raise
except SubAgentRecursionError as exc:
# Deliberate block — re-enqueueing would immediately hit the limit
# again, so skip recovery and just surface the error.
yield "session_id", sid
yield "error", str(exc)
except Exception as exc:
yield "session_id", sid
# Recovery enqueue must happen BEFORE yielding "error": the block
# framework (_base.execute) raises BlockExecutionError immediately
# when it sees ("error", ...) and stops consuming the generator,
# so any code after that yield is dead code in production.
effective_prompt = input_data.prompt
if input_data.system_context:
effective_prompt = (
f"[System Context: {input_data.system_context}]\n\n"
f"{input_data.prompt}"
)
try:
await _enqueue_for_recovery(
sid,
execution_context.user_id,
effective_prompt,
input_data.dry_run or execution_context.dry_run,
)
except asyncio.CancelledError:
# Task cancelled during recovery — still yield the error
# so the session_id + error pair is visible before re-raising.
yield "error", str(exc)
raise
except Exception:
logger.warning(
"AutoPilot session %s: recovery enqueue raised unexpectedly",
sid[:12],
exc_info=True,
)
yield "error", str(exc)
@@ -539,13 +439,13 @@ def _check_recursion(
when the caller exits to restore the previous depth.
Raises:
SubAgentRecursionError: If the current depth already meets or exceeds the limit.
RuntimeError: If the current depth already meets or exceeds the limit.
"""
current = _autopilot_recursion_depth.get()
inherited = _autopilot_recursion_limit.get()
limit = max_depth if inherited is None else min(inherited, max_depth)
if current >= limit:
raise SubAgentRecursionError(
raise RuntimeError(
f"AutoPilot recursion depth limit reached ({limit}). "
"The autopilot has called itself too many times."
)
@@ -636,51 +536,3 @@ def _merge_inherited_permissions(
# Return the token so the caller can restore the previous value in finally.
token = _inherited_permissions.set(merged)
return merged, token
# ---------------------------------------------------------------------------
# Recovery helpers
# ---------------------------------------------------------------------------
async def _enqueue_for_recovery(
session_id: str,
user_id: str,
message: str,
dry_run: bool,
) -> None:
"""Re-enqueue an orphaned sub-agent session so a fresh executor picks it up.
When ``execute_copilot`` raises an unexpected exception the sub-agent
session is left with ``last_role=user`` and no active consumer — identical
to the state that caused Toran's reports of silent sub-agents. Publishing
the original prompt back to the copilot queue lets the executor service
resume the session without manual intervention.
Skipped for dry-run sessions (no real consumers listen to the queue for
simulated sessions). Any failure to publish is logged and swallowed so
it never masks the original exception.
"""
if dry_run:
return
try:
from backend.copilot.executor.utils import ( # avoid circular import
enqueue_copilot_turn,
)
await asyncio.wait_for(
enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=message,
turn_id=str(uuid.uuid4()),
),
timeout=10,
)
logger.info("AutoPilot session %s enqueued for recovery", session_id[:12])
except Exception:
logger.warning(
"AutoPilot session %s: failed to enqueue for recovery",
session_id[:12],
exc_info=True,
)

@@ -62,14 +62,6 @@ class TestBuildAndValidatePermissions:
with pytest.raises(ValidationError, match="not_a_real_tool"):
_make_input(tools=["not_a_real_tool"])
async def test_disabled_legacy_tool_is_accepted_and_removed(self):
inp = _make_input(tools=["ask_question", "run_block"])
result = await _build_and_validate_permissions(inp)
assert inp.tools == ["run_block"]
assert isinstance(result, CopilotPermissions)
assert result.tools == ["run_block"]
async def test_valid_block_name_accepted(self):
mock_block_cls = MagicMock()
mock_block_cls.return_value.name = "HTTP Request"

@@ -1,26 +0,0 @@
"""Shared provider config for Ayrshare social-media blocks.
The "credential" exposed to blocks is the **per-user Ayrshare profile key**,
not the org-level ``AYRSHARE_API_KEY``. Profile keys are provisioned per
user by :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`
and stored in the normal credentials list with ``is_managed=True``, so every
Ayrshare block fits the standard credential flow:
credentials: CredentialsMetaInput = ayrshare.credentials_field(...)
``run_block`` / ``resolve_block_credentials`` take care of the rest.
``with_managed_api_key()`` registers ``api_key`` as a supported auth type
without the env-var-backed default credential that ``with_api_key()`` would
create — the org-level ``AYRSHARE_API_KEY`` is the admin key and must never
reach a block as a "profile key".
"""
from backend.sdk import ProviderBuilder
ayrshare = (
ProviderBuilder("ayrshare")
.with_description("Post to every social network")
.with_managed_api_key()
.build()
)

@@ -1,18 +0,0 @@
from backend.sdk import BlockCost, BlockCostType
# Ayrshare is a subscription proxy ($149/mo Business). Per-post credit charges
# prevent a single heavy user from absorbing the fixed cost and align with the
# upload cost of each post variant.
# cost_filter matches on input_data.is_video BEFORE run() executes, so the flag
# has to be correct at input-eval time. Video-only platforms (YouTube, Snapchat)
# override the base default to True; platforms that accept both (TikTok, etc.)
# rely on the caller setting is_video explicitly for accurate billing.
# First match wins in block_usage_cost, so list the video tier first.
AYRSHARE_POST_COSTS = (
BlockCost(
cost_amount=5, cost_type=BlockCostType.RUN, cost_filter={"is_video": True}
),
BlockCost(
cost_amount=2, cost_type=BlockCostType.RUN, cost_filter={"is_video": False}
),
)
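
Note on the ordering comment above ("First match wins in block_usage_cost"): a small sketch of first-match-wins selection over a tuple of filtered costs. `SketchCost` and `pick_cost` are hypothetical; the actual matching logic in `block_usage_cost` is not part of this diff, so the equality-based filter check here is an assumption.

```python
from dataclasses import dataclass, field
from typing import Any, Mapping, Sequence


@dataclass(frozen=True)
class SketchCost:
    """Hypothetical stand-in for a cost tier with an input filter."""

    amount: int
    cost_filter: dict = field(default_factory=dict)


def pick_cost(costs: Sequence[SketchCost], input_data: Mapping[str, Any]) -> int:
    """Return the amount of the first tier whose filter matches the input."""
    for cost in costs:
        # Assumed rule: every filter key must equal the input value.
        if all(input_data.get(key) == value for key, value in cost.cost_filter.items()):
            return cost.amount
    return 0  # no tier matched


# Same shape as AYRSHARE_POST_COSTS: video tier first, then non-video.
POST_COSTS = (
    SketchCost(5, {"is_video": True}),
    SketchCost(2, {"is_video": False}),
)

# A catch-all tier (empty filter) listed first would shadow the video tier.
SHADOWED_COSTS = (
    SketchCost(2),
    SketchCost(5, {"is_video": True}),
)
```

Under this sketch, `pick_cost(SHADOWED_COSTS, {"is_video": True})` returns 2 rather than 5, which is the failure mode the ordering comment guards against. It also shows why the flag has to be right before `run()` executes: the tier is chosen from the input alone.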

@@ -4,25 +4,22 @@ from typing import Optional
from pydantic import BaseModel, Field
from backend.blocks._base import BlockSchemaInput
from backend.data.model import CredentialsMetaInput, SchemaField
from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import MissingConfigError
from ._config import ayrshare
async def get_profile_key(user_id: str):
user_integrations: UserIntegrations = (
await get_database_manager_async_client().get_user_integrations(user_id)
)
return user_integrations.managed_credentials.ayrshare_profile_key
class BaseAyrshareInput(BlockSchemaInput):
"""Base input model for Ayrshare social media posts with common fields."""
credentials: CredentialsMetaInput = ayrshare.credentials_field(
description=(
"Ayrshare profile credential. AutoGPT provisions this managed "
"credential automatically — the user does not create it. After "
"it's in place, the user links each social account via the "
"Ayrshare SSO popup in the Builder."
),
)
post: str = SchemaField(
description="The post text to be published", default="", advanced=False
)
@@ -32,9 +29,7 @@ class BaseAyrshareInput(BlockSchemaInput):
advanced=False,
)
is_video: bool = SchemaField(
description="Whether the media is a video. Set to True when uploading a video so billing applies the video tier.",
default=False,
advanced=True,
description="Whether the media is a video", default=False, advanced=True
)
schedule_date: Optional[datetime] = SchemaField(
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToBlueskyBlock(Block):
"""Block for posting to Bluesky with Bluesky-specific options."""
@@ -61,10 +57,16 @@ class PostToBlueskyBlock(Block):
self,
input_data: "PostToBlueskyBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Bluesky with Bluesky-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -104,7 +106,7 @@ class PostToBlueskyBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
bluesky_options=bluesky_options if bluesky_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
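
Note on the `run()` change above: every Ayrshare block in this diff uses the same guard, looking up the per-user profile key, yielding an `"error"` output, and returning early when it is missing. A minimal sketch of that control flow as an async generator follows. `_PROFILE_KEYS`, `lookup_profile_key`, `run_post`, and `collect` are hypothetical stand-ins; the real `get_profile_key` reads from the database manager client and returns a secret value.

```python
import asyncio
from typing import AsyncIterator, Optional

# Hypothetical in-memory stand-in for the per-user integrations lookup.
_PROFILE_KEYS = {"user-1": "pk_123"}


async def lookup_profile_key(user_id: str) -> Optional[str]:
    """Return the user's profile key, or None if nothing is linked."""
    return _PROFILE_KEYS.get(user_id)


async def run_post(user_id: str, post: str) -> AsyncIterator[tuple[str, str]]:
    """Yield (output_name, value) pairs the way a block's run() does."""
    profile_key = await lookup_profile_key(user_id)
    if not profile_key:
        # Guard: report the problem and stop before any client call is made.
        yield "error", "Please link a social account via Ayrshare"
        return
    # The real block would call the Ayrshare client with profile_key here.
    yield "post_result", f"posted {post!r}"


async def collect(user_id: str, post: str) -> list[tuple[str, str]]:
    """Drain the generator into a list for inspection."""
    return [item async for item in run_post(user_id, post)]
```

Calling `asyncio.run(collect("nobody", "hi"))` yields only the error pair, so nothing after the guard runs for a user without a linked profile.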

View File

@@ -1,20 +1,21 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client
from ._util import (
BaseAyrshareInput,
CarouselItem,
create_ayrshare_client,
get_profile_key,
)
@cost(*AYRSHARE_POST_COSTS)
class PostToFacebookBlock(Block):
"""Block for posting to Facebook with Facebook-specific options."""
@@ -119,10 +120,15 @@ class PostToFacebookBlock(Block):
self,
input_data: "PostToFacebookBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Facebook with Facebook-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -198,7 +204,7 @@ class PostToFacebookBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
facebook_options=facebook_options if facebook_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToGMBBlock(Block):
"""Block for posting to Google My Business with GMB-specific options."""
@@ -114,13 +110,14 @@ class PostToGMBBlock(Block):
)
async def run(
self,
input_data: "PostToGMBBlock.Input",
*,
credentials: APIKeyCredentials,
**kwargs
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
"""Post to Google My Business with GMB-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -205,7 +202,7 @@ class PostToGMBBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
gmb_options=gmb_options if gmb_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -2,21 +2,22 @@ from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client
from ._util import (
BaseAyrshareInput,
InstagramUserTag,
create_ayrshare_client,
get_profile_key,
)
@cost(*AYRSHARE_POST_COSTS)
class PostToInstagramBlock(Block):
"""Block for posting to Instagram with Instagram-specific options."""
@@ -111,10 +112,15 @@ class PostToInstagramBlock(Block):
self,
input_data: "PostToInstagramBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Instagram with Instagram-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -235,7 +241,7 @@ class PostToInstagramBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
instagram_options=instagram_options if instagram_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToLinkedInBlock(Block):
"""Block for posting to LinkedIn with LinkedIn-specific options."""
@@ -116,10 +112,15 @@ class PostToLinkedInBlock(Block):
self,
input_data: "PostToLinkedInBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to LinkedIn with LinkedIn-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -213,7 +214,7 @@ class PostToLinkedInBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
linkedin_options=linkedin_options if linkedin_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,21 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client
from ._util import (
BaseAyrshareInput,
PinterestCarouselOption,
create_ayrshare_client,
get_profile_key,
)
@cost(*AYRSHARE_POST_COSTS)
class PostToPinterestBlock(Block):
"""Block for posting to Pinterest with Pinterest-specific options."""
@@ -91,10 +92,15 @@ class PostToPinterestBlock(Block):
self,
input_data: "PostToPinterestBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Pinterest with Pinterest-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -200,7 +206,7 @@ class PostToPinterestBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
pinterest_options=pinterest_options if pinterest_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToRedditBlock(Block):
"""Block for posting to Reddit."""
@@ -39,12 +35,12 @@ class PostToRedditBlock(Block):
)
async def run(
self,
input_data: "PostToRedditBlock.Input",
*,
credentials: APIKeyCredentials,
**kwargs
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured."
@@ -65,7 +61,7 @@ class PostToRedditBlock(Block):
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToSnapchatBlock(Block):
"""Block for posting to Snapchat with Snapchat-specific options."""
@@ -35,14 +31,6 @@ class PostToSnapchatBlock(Block):
advanced=False,
)
# Snapchat is video-only; override the base default so the @cost filter
# selects the 5-credit video tier instead of the 2-credit image tier.
is_video: bool = SchemaField(
description="Whether the media is a video (always True for Snapchat)",
default=True,
advanced=True,
)
# Snapchat-specific options
story_type: str = SchemaField(
description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
@@ -74,10 +62,15 @@ class PostToSnapchatBlock(Block):
self,
input_data: "PostToSnapchatBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Snapchat with Snapchat-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -128,7 +121,7 @@ class PostToSnapchatBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
snapchat_options=snapchat_options if snapchat_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToTelegramBlock(Block):
"""Block for posting to Telegram with Telegram-specific options."""
@@ -61,10 +57,15 @@ class PostToTelegramBlock(Block):
self,
input_data: "PostToTelegramBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Telegram with Telegram-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -107,7 +108,7 @@ class PostToTelegramBlock(Block):
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToThreadsBlock(Block):
"""Block for posting to Threads with Threads-specific options."""
@@ -54,10 +50,15 @@ class PostToThreadsBlock(Block):
self,
input_data: "PostToThreadsBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Threads with Threads-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -102,7 +103,7 @@ class PostToThreadsBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
threads_options=threads_options if threads_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -2,18 +2,15 @@ from enum import Enum
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class TikTokVisibility(str, Enum):
@@ -22,7 +19,6 @@ class TikTokVisibility(str, Enum):
FOLLOWERS = "followers"
@cost(*AYRSHARE_POST_COSTS)
class PostToTikTokBlock(Block):
"""Block for posting to TikTok with TikTok-specific options."""
@@ -117,13 +113,14 @@ class PostToTikTokBlock(Block):
)
async def run(
self,
input_data: "PostToTikTokBlock.Input",
*,
credentials: APIKeyCredentials,
**kwargs,
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
"""Post to TikTok with TikTok-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -238,7 +235,7 @@ class PostToTikTokBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
tiktok_options=tiktok_options if tiktok_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -1,20 +1,16 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
@cost(*AYRSHARE_POST_COSTS)
class PostToXBlock(Block):
"""Block for posting to X / Twitter with Twitter-specific options."""
@@ -119,10 +115,15 @@ class PostToXBlock(Block):
self,
input_data: "PostToXBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to X / Twitter with enhanced X-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -232,7 +233,7 @@ class PostToXBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
twitter_options=twitter_options if twitter_options else None,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -3,18 +3,15 @@ from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class YouTubeVisibility(str, Enum):
@@ -23,7 +20,6 @@ class YouTubeVisibility(str, Enum):
UNLISTED = "unlisted"
@cost(*AYRSHARE_POST_COSTS)
class PostToYouTubeBlock(Block):
"""Block for posting to YouTube with YouTube-specific options."""
@@ -43,14 +39,6 @@ class PostToYouTubeBlock(Block):
advanced=False,
)
# YouTube is video-only; override the base default so the @cost filter
# selects the 5-credit video tier instead of the 2-credit image tier.
is_video: bool = SchemaField(
description="Whether the media is a video (always True for YouTube)",
default=True,
advanced=True,
)
# YouTube-specific required options
title: str = SchemaField(
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
@@ -149,10 +137,16 @@ class PostToYouTubeBlock(Block):
self,
input_data: "PostToYouTubeBlock.Input",
*,
credentials: APIKeyCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to YouTube with YouTube-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -308,7 +302,7 @@ class PostToYouTubeBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
youtube_options=youtube_options,
profile_key=credentials.api_key.get_secret_value(),
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

@@ -7,7 +7,6 @@ from backend.sdk import BlockCostType, ProviderBuilder
# Configure the Meeting BaaS provider with API key authentication
baas = (
ProviderBuilder("baas")
.with_description("Meeting recording and transcription")
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
.build()

@@ -4,34 +4,21 @@ Meeting BaaS bot (recording) blocks.
from typing import Optional
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
cost,
)
from ._api import MeetingBaasAPI
from ._config import baas
# Meeting BaaS recording rate: $0.69 per hour.
_MEETING_BAAS_USD_PER_SECOND = 0.69 / 3600
# Join bills a flat 30 cr commit (covers median short meeting);
# FetchMeetingData bills the duration-scaled remainder from the
# `duration_seconds` field on the API response. Long meetings no
# longer under-bill.
@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=30))
class BaasBotJoinMeetingBlock(Block):
"""
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
@@ -147,7 +134,6 @@ class BaasBotLeaveMeetingBlock(Block):
yield "left", left
@cost(BlockCost(cost_type=BlockCostType.COST_USD, cost_amount=150))
class BaasBotFetchMeetingDataBlock(Block):
"""
Pull MP4 URL, transcript & metadata for a completed meeting.
@@ -190,21 +176,9 @@ class BaasBotFetchMeetingDataBlock(Block):
include_transcripts=input_data.include_transcripts,
)
bot_meta = data.get("bot_data", {}).get("bot", {}) or {}
# Bill recording duration via COST_USD so multi-hour meetings
# scale past the Join block's flat 30 cr deposit.
duration_seconds = float(bot_meta.get("duration_seconds") or 0)
if duration_seconds > 0:
self.merge_stats(
NodeExecutionStats(
provider_cost=duration_seconds * _MEETING_BAAS_USD_PER_SECOND,
provider_cost_type="cost_usd",
)
)
yield "mp4_url", data.get("mp4", "")
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
yield "metadata", bot_meta
yield "metadata", data.get("bot_data", {}).get("bot", {})
class BaasBotDeleteRecordingBlock(Block):

@@ -1,86 +0,0 @@
"""Unit tests for Meeting BaaS duration-based cost emission."""
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.baas.bots import (
_MEETING_BAAS_USD_PER_SECOND,
BaasBotFetchMeetingDataBlock,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="baas",
title="Mock BaaS API Key",
api_key=SecretStr("mock-baas-api-key"),
expires_at=None,
)
def test_usd_per_second_derives_from_published_rate():
"""$0.69/hour published rate → ~$0.000192/second."""
assert _MEETING_BAAS_USD_PER_SECOND == pytest.approx(0.69 / 3600)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"duration_seconds, expected_usd",
[
(3600, 0.69), # 1 hour
(1800, 0.345), # 30 min
(0, None), # no recording → no emission
(None, None), # missing duration field → no emission
],
)
async def test_fetch_meeting_data_emits_duration_cost_usd(
duration_seconds, expected_usd
):
"""FetchMeetingData extracts duration_seconds from bot metadata and
emits provider_cost / cost_usd scaled by the published $0.69/hr rate.
Emission is skipped when duration is 0 or missing.
"""
block = BaasBotFetchMeetingDataBlock()
bot_meta = {"id": "bot-xyz"}
if duration_seconds is not None:
bot_meta["duration_seconds"] = duration_seconds
mock_api = AsyncMock()
mock_api.get_meeting_data.return_value = {
"mp4": "https://example/recording.mp4",
"bot_data": {"bot": bot_meta, "transcripts": []},
}
captured: list[NodeExecutionStats] = []
with (
patch("backend.blocks.baas.bots.MeetingBaasAPI", return_value=mock_api),
patch.object(block, "merge_stats", side_effect=captured.append),
):
outputs = []
async for name, val in block.run(
block.input_schema(
credentials={
"id": TEST_CREDENTIALS.id,
"provider": TEST_CREDENTIALS.provider,
"type": TEST_CREDENTIALS.type,
},
bot_id="bot-xyz",
include_transcripts=False,
),
credentials=TEST_CREDENTIALS,
):
outputs.append((name, val))
# Always yields the 3 outputs regardless of duration.
names = [n for n, _ in outputs]
assert "mp4_url" in names and "metadata" in names
if expected_usd is None:
assert captured == []
else:
assert len(captured) == 1
assert captured[0].provider_cost == pytest.approx(expected_usd)
assert captured[0].provider_cost_type == "cost_usd"
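
The deleted test above pins the duration-based billing arithmetic: recording seconds multiplied by the published $0.69/hour rate, with no emission when the duration is zero or missing. A minimal standalone sketch of that arithmetic follows. `duration_cost_usd` is a hypothetical helper name used only for illustration; it is not a function in the codebase, and the real block emits the value through `merge_stats` rather than returning it.

```python
from typing import Optional

# Published Meeting BaaS recording rate quoted in the diff: $0.69 per hour.
USD_PER_SECOND = 0.69 / 3600


def duration_cost_usd(duration_seconds: Optional[float]) -> Optional[float]:
    """Return the USD cost of a recording, or None when there is nothing to bill.

    Hypothetical helper: it mirrors the `duration_seconds or 0` guard from the
    diff, so both a missing field and a zero duration yield no cost.
    """
    seconds = float(duration_seconds or 0)
    if seconds <= 0:
        return None
    return seconds * USD_PER_SECOND
```

Returning `None` rather than `0.0` keeps the "no emission" case distinguishable from a genuinely free run, which is the behaviour the parametrized cases check.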

@@ -2,8 +2,7 @@ from backend.sdk import BlockCostType, ProviderBuilder
bannerbear = (
ProviderBuilder("bannerbear")
.with_description("Auto-generate images and videos")
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
.with_base_cost(3, BlockCostType.RUN)
.with_base_cost(1, BlockCostType.RUN)
.build()
)

@@ -433,7 +433,7 @@ class TestJinaEmbeddingBlockCostTracking:
class TestUnrealTextToSpeechBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_character_count(self):
"""provider_cost = len(text) * $0.000016 with type='cost_usd'."""
"""provider_cost equals len(text) with type='characters'."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
@@ -461,12 +461,12 @@ class TestUnrealTextToSpeechBlockCostTracking:
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == pytest.approx(len(test_text) * 0.000016)
assert stats.provider_cost_type == "cost_usd"
assert stats.provider_cost == float(len(test_text))
assert stats.provider_cost_type == "characters"
@pytest.mark.asyncio
async def test_empty_text_gives_zero_characters(self):
"""An empty text string results in provider_cost=0.0 (cost_usd)."""
"""An empty text string results in provider_cost=0.0."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
@@ -494,7 +494,7 @@ class TestUnrealTextToSpeechBlockCostTracking:
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == 0.0
assert stats.provider_cost_type == "cost_usd"
assert stats.provider_cost_type == "characters"
# ---------------------------------------------------------------------------

@@ -17,7 +17,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -432,7 +431,6 @@ class ClaudeCodeBlock(Block):
# The JSON output contains the result
output_data = json.loads(raw_output)
response = output_data.get("result", raw_output)
self._record_cli_cost(output_data)
# Build conversation history entry
turn_entry = f"User: {prompt}\nClaude: {response}"
@@ -486,23 +484,6 @@ class ClaudeCodeBlock(Block):
escaped = prompt.replace("'", "'\"'\"'")
return f"'{escaped}'"
def _record_cli_cost(self, output_data: dict) -> None:
"""Feed Claude Code CLI's `total_cost_usd` to the COST_USD resolver.
The CLI rolls up Anthropic LLM + internal tool-call spend into
``total_cost_usd`` on its JSON response; piping it through
``merge_stats`` lets the wallet reflect real spend.
"""
total_cost_usd = output_data.get("total_cost_usd")
if total_cost_usd is None:
return
self.merge_stats(
NodeExecutionStats(
provider_cost=float(total_cost_usd),
provider_cost_type="cost_usd",
)
)
async def run(
self,
input_data: Input,

@@ -1,106 +0,0 @@
"""Unit tests for ClaudeCodeBlock COST_USD billing migration.
Verifies:
- Block emits provider_cost / cost_usd when Claude Code CLI returns
total_cost_usd.
- block_usage_cost resolves the COST_USD entry to the expected ceil(usd *
cost_amount) credit charge.
- Missing total_cost_usd gracefully produces provider_cost=None (no bill).
"""
from unittest.mock import MagicMock, patch
import pytest
from backend.blocks._base import BlockCostType
from backend.blocks.claude_code import ClaudeCodeBlock
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.model import NodeExecutionStats
from backend.executor.utils import block_usage_cost
def test_claude_code_registered_as_cost_usd_150():
"""Sanity: BLOCK_COSTS holds the COST_USD, 150 cr/$ entry."""
entries = BLOCK_COSTS[ClaudeCodeBlock]
assert len(entries) == 1
entry = entries[0]
assert entry.cost_type == BlockCostType.COST_USD
assert entry.cost_amount == 150
@pytest.mark.parametrize(
"total_cost_usd, expected_credits",
[
(0.50, 75), # $0.50 × 150 = 75 cr
(1.00, 150), # $1.00 × 150 = 150 cr
(0.0134, 3), # ceil(0.0134 × 150) = ceil(2.01) = 3
(2.00, 300), # $2 × 150 = 300 cr
(0.001, 1), # ceil(0.001 × 150) = ceil(0.15) = 1 — no 0-cr leak on
# sub-cent runs
],
)
def test_cost_usd_resolver_applies_150_multiplier(total_cost_usd, expected_credits):
"""block_usage_cost with cost_usd stats returns ceil(usd * 150)."""
block = ClaudeCodeBlock()
# cost_filter requires matching e2b_credentials; supply the ones the
# registration uses so _is_cost_filter_match accepts the input.
entry = BLOCK_COSTS[ClaudeCodeBlock][0]
input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
stats = NodeExecutionStats(
provider_cost=total_cost_usd,
provider_cost_type="cost_usd",
)
cost, matching_filter = block_usage_cost(
block=block, input_data=input_data, stats=stats
)
assert cost == expected_credits
assert matching_filter == entry.cost_filter
def test_cost_usd_resolver_returns_zero_when_stats_missing_cost():
"""Pre-flight (no stats) or unbilled run (provider_cost None) → 0."""
block = ClaudeCodeBlock()
entry = BLOCK_COSTS[ClaudeCodeBlock][0]
input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
# No stats at all → pre-flight path, returns 0.
pre_cost, _ = block_usage_cost(block=block, input_data=input_data)
assert pre_cost == 0
# Stats present but no provider_cost → resolver can't bill.
stats = NodeExecutionStats()
post_cost, _ = block_usage_cost(block=block, input_data=input_data, stats=stats)
assert post_cost == 0
def test_record_cli_cost_emits_provider_cost_when_total_cost_present():
"""``_record_cli_cost`` (the helper called from ``execute_claude_code``)
must emit a single ``merge_stats`` with provider_cost + cost_usd tag
when the CLI JSON payload carries ``total_cost_usd``.
"""
block = ClaudeCodeBlock()
captured: list[NodeExecutionStats] = []
with patch.object(block, "merge_stats", side_effect=captured.append):
block._record_cli_cost(
{
"result": "hello from claude",
"total_cost_usd": 0.0421,
"usage": {"input_tokens": 1234, "output_tokens": 56},
}
)
assert len(captured) == 1
stats = captured[0]
assert stats.provider_cost == pytest.approx(0.0421)
assert stats.provider_cost_type == "cost_usd"
def test_record_cli_cost_skips_merge_when_total_cost_absent():
"""If the CLI payload lacks ``total_cost_usd`` (legacy / non-JSON
output), ``_record_cli_cost`` must not call ``merge_stats`` — otherwise
we'd pollute telemetry with a ``cost_usd`` emission that has no real
cost attached.
"""
block = ClaudeCodeBlock()
mock = MagicMock()
with patch.object(block, "merge_stats", mock):
block._record_cli_cost({"result": "hello"})
mock.assert_not_called()
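
The deleted tests above describe the COST_USD conversion as `ceil(usd * cost_amount)` with a `cost_amount` of 150 credits per dollar. The sketch below reproduces only that arithmetic so the parametrized table can be checked by hand. `usd_to_credits` is a hypothetical name; the real logic lives in `block_usage_cost`, whose implementation is not shown in this diff, and the guard against negative values is an assumption.

```python
import math
from typing import Optional

# cost_amount asserted by the registration test: 150 credits per USD.
CREDITS_PER_USD = 150


def usd_to_credits(
    provider_cost_usd: Optional[float],
    credits_per_usd: int = CREDITS_PER_USD,
) -> int:
    """Convert a provider-reported USD cost into whole platform credits.

    Hypothetical helper. A missing cost bills nothing; any positive cost is
    rounded up, so sub-cent runs still cost at least one credit.
    """
    if provider_cost_usd is None or provider_cost_usd <= 0:
        return 0  # assumption: negative or absent costs are never billed
    return math.ceil(provider_cost_usd * credits_per_usd)
```

Because the product is a binary float, a value that lands a hair above an integer would round up to the next credit; an implementation that needs exact cents may prefer `decimal.Decimal`.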

@@ -151,17 +151,6 @@ class CodeGenerationBlock(Block):
)
self.execution_stats = NodeExecutionStats()
# GPT-5.1-Codex published pricing: $1.25 / 1M input, $10 / 1M output.
_INPUT_USD_PER_1M = 1.25
_OUTPUT_USD_PER_1M = 10.0
@staticmethod
def _compute_token_usd(input_tokens: int, output_tokens: int) -> float:
return (
input_tokens * CodeGenerationBlock._INPUT_USD_PER_1M
+ output_tokens * CodeGenerationBlock._OUTPUT_USD_PER_1M
) / 1_000_000
async def call_codex(
self,
*,
@@ -200,15 +189,13 @@ class CodeGenerationBlock(Block):
response_id = response.id or ""
# Update usage stats
input_tokens = response.usage.input_tokens if response.usage else 0
output_tokens = response.usage.output_tokens if response.usage else 0
self.execution_stats.input_token_count = input_tokens
self.execution_stats.output_token_count = output_tokens
self.execution_stats.llm_call_count += 1
self.execution_stats.provider_cost = self._compute_token_usd(
input_tokens, output_tokens
self.execution_stats.input_token_count = (
response.usage.input_tokens if response.usage else 0
)
self.execution_stats.provider_cost_type = "cost_usd"
self.execution_stats.output_token_count = (
response.usage.output_tokens if response.usage else 0
)
self.execution_stats.llm_call_count += 1
return CodexCallResult(
response=text_output,

@@ -1,10 +0,0 @@
"""Provider registration for Compass — metadata only (auth lives elsewhere)."""
from backend.sdk import ProviderBuilder
compass = (
ProviderBuilder("compass")
.with_description("Geospatial context for agents")
.with_supported_auth_types("api_key")
.build()
)

@@ -1,226 +0,0 @@
"""Coverage tests for the cost-leak fixes in this PR.
Each block's ``run()`` / helper emits provider_cost + cost_usd (or items)
via merge_stats so the post-flight resolver bills real provider spend.
Tests here drive that emission path directly so a regression on any one
block surfaces immediately.
"""
from unittest.mock import patch
import pytest
from pydantic import SecretStr
from backend.blocks._base import BlockCostType
from backend.blocks.ai_condition import AIConditionBlock
from backend.data.block_cost_config import BLOCK_COSTS, LLM_COST
from backend.data.model import APIKeyCredentials, NodeExecutionStats
# -------- AIConditionBlock registration --------
def test_ai_condition_registered_under_llm_cost():
"""AIConditionBlock was running wallet-free before this PR; verify it
now resolves through the same per-model LLM_COST table as every other
LLM block.
"""
assert BLOCK_COSTS[AIConditionBlock] is LLM_COST
# -------- Pinecone insert ITEMS emission --------
@pytest.mark.asyncio
async def test_pinecone_insert_emits_items_provider_cost():
from backend.blocks.pinecone import PineconeInsertBlock
block = PineconeInsertBlock()
captured: list[NodeExecutionStats] = []
class _FakeIndex:
def upsert(self, **_):
return None
class _FakePinecone:
def __init__(self, *_, **__):
pass
def Index(self, _name):
return _FakeIndex()
with (
patch("backend.blocks.pinecone.Pinecone", _FakePinecone),
patch.object(block, "merge_stats", side_effect=captured.append),
):
input_data = block.input_schema(
credentials={
"id": "00000000-0000-0000-0000-000000000000",
"provider": "pinecone",
"type": "api_key",
},
index="my-index",
chunks=["alpha", "beta", "gamma"],
embeddings=[[0.1] * 4, [0.2] * 4, [0.3] * 4],
namespace="",
metadata={},
)
creds = APIKeyCredentials(
id="00000000-0000-0000-0000-000000000000",
provider="pinecone",
title="mock",
api_key=SecretStr("mock-key"),
expires_at=None,
)
outputs = [(n, v) async for n, v in block.run(input_data, credentials=creds)]
assert any(name == "upsert_response" for name, _ in outputs)
assert len(captured) == 1
stats = captured[0]
assert stats.provider_cost == pytest.approx(3.0)
assert stats.provider_cost_type == "items"
# -------- Narration model-aware per-char rate --------
@pytest.mark.parametrize(
"model_id, expected_rate_per_char",
[
("eleven_flash_v2_5", 0.000167 * 0.5),
("eleven_turbo_v2_5", 0.000167 * 0.5),
("eleven_multilingual_v2", 0.000167 * 1.0),
("eleven_turbo_v2", 0.000167 * 1.0),
],
)
def test_narration_per_char_rate_scales_with_model(model_id, expected_rate_per_char):
"""Drive VideoNarrationBlock._record_script_cost directly so a regression
that drops the model-aware branching (e.g. hardcoding 1.0 cr/char for
all models) makes this test fail.
"""
from backend.blocks.video.narration import VideoNarrationBlock
block = VideoNarrationBlock()
captured: list[NodeExecutionStats] = []
with patch.object(block, "merge_stats", side_effect=captured.append):
block._record_script_cost("x" * 5000, model_id)
assert len(captured) == 1
stats = captured[0]
assert stats.provider_cost == pytest.approx(5000 * expected_rate_per_char)
assert stats.provider_cost_type == "cost_usd"
# -------- Perplexity None-guard on x-total-cost --------
@pytest.mark.parametrize(
"openrouter_cost, expect_type",
[
(0.0421, "cost_usd"), # concrete positive USD → tagged
(None, None), # header missing → no tag (keeps gap observable)
(0.0, None), # zero → no tag (wouldn't bill anything anyway)
],
)
def test_perplexity_record_openrouter_cost_tags_only_on_concrete_value(
openrouter_cost, expect_type
):
"""Drive PerplexityBlock._record_openrouter_cost directly to verify the
None/0 guard. A regression that tags cost_usd unconditionally would
silently floor the user's bill to 0 via the resolver — this test
would catch it.
"""
from backend.blocks.perplexity import PerplexityBlock
block = PerplexityBlock()
with patch(
"backend.blocks.perplexity.extract_openrouter_cost",
return_value=openrouter_cost,
):
block._record_openrouter_cost(response=object())
assert block.execution_stats.provider_cost == openrouter_cost
assert block.execution_stats.provider_cost_type == expect_type
# -------- Codex COST_USD registration --------
def test_codex_registered_as_cost_usd_150():
from backend.blocks.codex import CodeGenerationBlock
entries = BLOCK_COSTS[CodeGenerationBlock]
assert len(entries) == 1
entry = entries[0]
assert entry.cost_type == BlockCostType.COST_USD
assert entry.cost_amount == 150
@pytest.mark.parametrize(
"input_tokens, output_tokens, expected_usd",
[
# GPT-5.1-Codex: $1.25 / 1M input, $10 / 1M output.
(1_000_000, 0, 1.25),
(0, 1_000_000, 10.0),
(100_000, 10_000, 0.225), # 0.125 + 0.100
(0, 0, 0.0),
],
)
def test_codex_computes_provider_cost_usd_from_token_counts(
input_tokens, output_tokens, expected_usd
):
"""Drive CodeGenerationBlock._compute_token_usd directly. A regression
to the wrong rate constants (e.g. swapping the $1.25 input rate for
GPT-4o's $2.50) would fail this test.
"""
from backend.blocks.codex import CodeGenerationBlock
assert CodeGenerationBlock._compute_token_usd(
input_tokens, output_tokens
) == pytest.approx(expected_usd)
# -------- ClaudeCode COST_USD registration sanity (already tested in claude_code_cost_test.py) --------
# -------- Perplexity COST_USD registration for all 3 tiers --------
def test_perplexity_sonar_all_tiers_registered_as_cost_usd_150():
from backend.blocks.perplexity import PerplexityBlock
entries = BLOCK_COSTS[PerplexityBlock]
# 3 tiers (SONAR, SONAR_PRO, SONAR_DEEP_RESEARCH) all COST_USD 150.
assert len(entries) == 3
for entry in entries:
assert entry.cost_type == BlockCostType.COST_USD
assert entry.cost_amount == 150
# -------- Narration COST_USD registration --------
def test_narration_registered_as_cost_usd_150():
from backend.blocks.video.narration import VideoNarrationBlock
entries = BLOCK_COSTS[VideoNarrationBlock]
assert len(entries) == 1
assert entries[0].cost_type == BlockCostType.COST_USD
assert entries[0].cost_amount == 150
# -------- Pinecone registrations --------
def test_pinecone_registrations():
from backend.blocks.pinecone import (
PineconeInitBlock,
PineconeInsertBlock,
PineconeQueryBlock,
)
assert BLOCK_COSTS[PineconeInitBlock][0].cost_type == BlockCostType.RUN
assert BLOCK_COSTS[PineconeQueryBlock][0].cost_type == BlockCostType.RUN
# Insert scales with item count.
assert BLOCK_COSTS[PineconeInsertBlock][0].cost_type == BlockCostType.ITEMS
assert BLOCK_COSTS[PineconeInsertBlock][0].cost_amount == 1
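
The Perplexity test earlier in this deleted file checks a guard: the reported cost is always stored, but the `cost_usd` tag is applied only when the cost is a concrete positive number. A minimal sketch of that guard, under stated assumptions, is below. `Stats` is a stand-in for `NodeExecutionStats` with only the two fields involved, and `record_cost` is a hypothetical name; the actual `_record_openrouter_cost` implementation is not part of this diff.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Stats:
    """Stand-in for NodeExecutionStats, reduced to the two fields used here."""

    provider_cost: Optional[float] = None
    provider_cost_type: Optional[str] = None


def record_cost(stats: Stats, cost_usd: Optional[float]) -> Stats:
    """Store the reported cost and tag it as USD only when it is positive.

    Hypothetical helper: leaving the tag unset for None or 0 keeps a missing
    cost observable instead of resolving it to a zero-credit bill.
    """
    stats.provider_cost = cost_usd
    if cost_usd is not None and cost_usd > 0:
        stats.provider_cost_type = "cost_usd"
    return stats
```

The three parametrized cases map directly onto this: `0.0421` is tagged, while `None` and `0.0` are stored untagged.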

@@ -19,10 +19,6 @@ class DataForSeoClient:
trusted_origins=["https://api.dataforseo.com"],
raise_for_status=False,
)
# USD cost reported by DataForSEO on the most recent successful call.
# Populated by keyword_suggestions / related_keywords so the caller
# can surface it via NodeExecutionStats.provider_cost for billing.
self.last_cost_usd: float = 0.0
def _get_headers(self) -> Dict[str, str]:
"""Generate the authorization header using Basic Auth."""
@@ -101,9 +97,6 @@ class DataForSeoClient:
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
# DataForSEO reports per-task USD cost; stash it so callers
# can populate NodeExecutionStats.provider_cost.
self.last_cost_usd = float(task.get("cost") or 0.0)
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")
@@ -181,9 +174,6 @@ class DataForSeoClient:
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
# DataForSEO reports per-task USD cost; stash it so callers
# can populate NodeExecutionStats.provider_cost.
self.last_cost_usd = float(task.get("cost") or 0.0)
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")

@@ -7,17 +7,11 @@ from backend.sdk import BlockCostType, ProviderBuilder
# Build the DataForSEO provider with username/password authentication
dataforseo = (
ProviderBuilder("dataforseo")
.with_description("SEO and SERP data")
.with_user_password(
username_env_var="DATAFORSEO_USERNAME",
password_env_var="DATAFORSEO_PASSWORD",
title="DataForSEO Credentials",
)
# DataForSEO reports USD cost per task (e.g. $0.001/keyword returned).
# DataForSeoClient stashes it on last_cost_usd; each block emits it via
# merge_stats so the COST_USD resolver bills against real spend.
# 1000 platform credits per USD → 1 credit per $0.001 (≈ 1 credit/
# returned keyword on the standard tier).
.with_base_cost(1000, BlockCostType.COST_USD)
.with_base_cost(1, BlockCostType.RUN)
.build()
)

@@ -4,7 +4,6 @@ DataForSEO Google Keyword Suggestions block.
from typing import Any, Dict, List, Optional
from backend.data.model import NodeExecutionStats
from backend.sdk import (
Block,
BlockCategory,
@@ -111,10 +110,8 @@ class DataForSeoKeywordSuggestionsBlock(Block):
test_output=[
(
"suggestion",
lambda x: (
hasattr(x, "keyword")
and x.keyword == "digital marketing strategy"
),
lambda x: hasattr(x, "keyword")
and x.keyword == "digital marketing strategy",
),
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
("total_count", 1),
@@ -170,16 +167,6 @@ class DataForSeoKeywordSuggestionsBlock(Block):
results = await self._fetch_keyword_suggestions(client, input_data)
# DataForSEO reports per-task USD cost on the response. Feed it
# into NodeExecutionStats so the COST_USD resolver bills the
# real provider spend at reconciliation time.
self.merge_stats(
NodeExecutionStats(
provider_cost=client.last_cost_usd,
provider_cost_type="cost_usd",
)
)
# Process and format the results
suggestions = []
if results and len(results) > 0:

@@ -4,7 +4,6 @@ DataForSEO Google Related Keywords block.
from typing import Any, Dict, List, Optional
from backend.data.model import NodeExecutionStats
from backend.sdk import (
Block,
BlockCategory,
@@ -178,16 +177,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
results = await self._fetch_related_keywords(client, input_data)
# DataForSEO reports per-task USD cost on the response. Feed it
# into NodeExecutionStats so the COST_USD resolver bills the
# real provider spend at reconciliation time.
self.merge_stats(
NodeExecutionStats(
provider_cost=client.last_cost_usd,
provider_cost_type="cost_usd",
)
)
# Process and format the results
related_keywords = []
if results and len(results) > 0:

@@ -1,10 +0,0 @@
"""Provider registration for Discord — metadata only (auth lives in ``_auth.py``)."""
from backend.sdk import ProviderBuilder
discord = (
ProviderBuilder("discord")
.with_description("Messages, channels, and servers")
.with_supported_auth_types("api_key", "oauth2")
.build()
)

@@ -1,10 +0,0 @@
"""Provider registration for ElevenLabs — metadata only (auth lives in ``_auth.py``)."""
from backend.sdk import ProviderBuilder
elevenlabs = (
ProviderBuilder("elevenlabs")
.with_description("Realistic AI voice synthesis")
.with_supported_auth_types("api_key")
.build()
)

@@ -1,10 +0,0 @@
"""Provider registration for Enrichlayer — metadata only (auth lives in ``_auth.py``)."""
from backend.sdk import ProviderBuilder
enrichlayer = (
ProviderBuilder("enrichlayer")
.with_description("Enrich leads with company data")
.with_supported_auth_types("api_key")
.build()
)

@@ -9,14 +9,8 @@ from ._webhook import ExaWebhookManager
# Configure the Exa provider once for all blocks
exa = (
ProviderBuilder("exa")
.with_description("Neural web search")
.with_api_key("EXA_API_KEY", "Exa API Key")
.with_webhook_manager(ExaWebhookManager)
# Exa returns `cost_dollars.total` on every response and ExaSearchBlock
# (plus ~45 sibling blocks that share this provider config) already
# populates NodeExecutionStats.provider_cost with it. Bill 100 credits
# per USD (~$0.01/credit): cheap searches stay at 12 credits, a Deep
# Research run at $0.20 lands at 20 credits, matching provider spend.
.with_base_cost(100, BlockCostType.COST_USD)
.with_base_cost(1, BlockCostType.RUN)
.build()
)

@@ -17,7 +17,6 @@ from backend.sdk import (
)
from ._config import exa
from .helpers import merge_exa_cost
class AnswerCitation(BaseModel):
@@ -112,7 +111,3 @@ class ExaAnswerBlock(Block):
yield "citations", citations
for citation in citations:
yield "citation", citation
# Current SDK AnswerResponse dataclass omits cost_dollars; helper
# no-ops today, but keeps billing wired when exa_py adds the field.
merge_exa_cost(self, response)

@@ -9,6 +9,7 @@ from typing import Union
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -22,7 +23,6 @@ from backend.sdk import (
)
from ._config import exa
from .helpers import merge_exa_cost
class CodeContextResponse(BaseModel):
@@ -118,5 +118,9 @@ class ExaCodeContextBlock(Block):
yield "search_time", context.search_time
yield "output_tokens", context.output_tokens
# API returns costDollars as a bare numeric string like "0.005".
merge_exa_cost(self, data)
# Parse cost_dollars (API returns as string, e.g. "0.005")
try:
cost_usd = float(context.cost_dollars)
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
except (ValueError, TypeError):
pass

@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -23,7 +24,6 @@ from .helpers import (
HighlightSettings,
LivecrawlTypes,
SummarySettings,
merge_exa_cost,
)
@@ -224,4 +224,6 @@ class ExaContentsBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
merge_exa_cost(self, response)
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

Some files were not shown because too many files have changed in this diff.