Compare commits

..

42 Commits

Author SHA1 Message Date
coderabbitai[bot]
3db2a944f7 fix: apply CodeRabbit auto-fixes
Fixed 1 file(s) based on 1 unresolved review comment.

Co-authored-by: CodeRabbit <noreply@coderabbit.ai>
2026-04-01 19:24:39 +00:00
Bentlybro
59192102a6 refactor(frontend): remove useMemo from useLibraryAgentList; fix useState staleness in useSitrepItems and useFleetSummary
- useLibraryAgentList: replace both useMemo calls with plain expressions;
  extract filterAgentsByStatus as a standalone helper function
- useSitrepItems: replace useState initializer with useMemo([agentIDs,
  maxItems]) so items recompute when the agent list changes
- useFleetSummary: replace useState initializer with useMemo([agentIDs])
  so fleet counts recompute when agents are added or removed
2026-04-01 19:10:03 +00:00
Bentlybro
65cca9bef8 fix(frontend): fix TS errors and AgentBriefingPanel regression
- Replace `variant="xsmall"` (not in Text component's type union) with
  `variant="small"` in PulseChips, StatsGrid, LibraryAgentCard, SitrepList
  — fixes the `check API types` CI failure
- useLibraryAgentList: expose `allAgentIDs` (unfiltered) and
  `displayedCount` (filtered count when filter is active)
- LibraryAgentList: pass `allAgentIDs` to AgentBriefingPanel so the
  sitrep covers the full fleet regardless of the active filter; pass
  `displayedCount` to the tab label so "All N" reflects the current view

Addresses Sentry comments 3023964054 and 3023964058.
2026-04-01 19:09:41 +00:00
Bentlybro
6b32e43d84 fix(frontend): wire statusFilter to agent list, fix consumePrompt, add NOTE comments
- useLibraryAgentList: accept statusFilter prop and apply client-side
  filtering using mockStatusForAgent; maps "attention"→health=attention,
  "healthy"→health=good, others match status directly
- LibraryAgentList: pass statusFilter through to useLibraryAgentList
- AutoPilotBridgeContext.consumePrompt: use `prompt !== null` instead of
  truthy check so empty string correctly clears sessionStorage
- Add NOTE comments near useState initialisers in useSitrepItems,
  useLibraryFleetSummary, useAgentStatus, and useFleetSummary explaining
  that they do not recompute on prop changes
2026-04-01 19:09:41 +00:00
Bentlybro
b73d05c23e fix(frontend): address review feedback — lint, format, and bug fixes
- Remove unused Button import in PulseChips (fixes lint CI failure)
- Fix AutoPilotBridgeContext: use Next.js router + sessionStorage instead
  of window.location.href which destroyed React state before consumption
- Add defensive handling in formatTimeAgo for invalid/future dates
- Use cn() utility in LibraryAgentCard for className consistency
- Fix prettier formatting across AgentFilterMenu, SitrepList, useAgentStatus
2026-03-31 17:22:23 +00:00
John Ababseh
8277cce835 feat(frontend): add Agent Intelligence Layer to library and home
Implements 7 new features for agent awareness across the platform:

1. Agent Briefing Panel — collapsible stats grid showing fleet-wide
   counts (running, error, listening, scheduled, idle, monthly spend)
2. Enhanced Library Cards — StatusBadge, progress bar, error messages,
   run counts, spend, and time-ago on every agent card
3. Expanded Filter & Sort — new AgentFilterMenu dropdown with status-
   based filtering (running, attention, listening, scheduled, idle)
4. AI Summary / Situation Report — prioritized SitrepList with error-
   first ranking and contextual action buttons
5. Ask AutoPilot Bridge — shared context that lets sitrep items and
   pulse chips pre-populate the Home chat with agent-specific prompts
6. Home Pulse Chips — lightweight agent status chips on the empty
   Home/Copilot screen linking back to the library
7. Contextual Action Buttons — status-aware actions (View error,
   Reconnect, Watch live, Run now, Retry) on cards and sitrep items

All features use deterministic mock data via useAgentStatus hook,
marked with TODO comments for backend API integration. Follows
existing component patterns (atoms/molecules/organisms), reuses
shadcn/ui primitives, Phosphor Icons, and platform design tokens.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 18:12:51 +02:00
Abhimanyu Yadav
57b17dc8e1 feat(platform): generic managed credential system with AgentMail auto-provisioning (#12537)
### Why / What / How

**Why:** We need a third credential type: **system-provided but unique
per user** (managed credentials). Currently we have system credentials
(same for all users) and user credentials (user provides their own
keys). Managed credentials bridge the gap — the platform provisions them
automatically, one per user, for integrations like AgentMail where each
user needs their own pod-scoped API key.

**What:**
- Generic **managed credential provider registry** — any integration can
register a provider that auto-provisions per-user credentials
- **AgentMail** is the first consumer: creates a pod + pod-scoped API
key using the org-level API key
- Managed credentials appear in the credential dropdown like normal API
keys but with `autogpt_managed=True` — users **cannot update or delete**
them
- **Auto-provisioning** on `GET /credentials` — lazily creates managed
credentials when users browse their credential list
- **Account deletion cleanup** utility — revokes external resources
(pods, API keys) before user deletion
- **Frontend UX** — hides the delete button for managed credentials on
the integrations page

**How:**

### Backend

**New files:**
- `backend/integrations/managed_credentials.py` —
`ManagedCredentialProvider` ABC, global registry,
`ensure_managed_credentials()` (with per-user asyncio lock +
`asyncio.gather` for concurrency), `cleanup_managed_credentials()`
- `backend/integrations/managed_providers/__init__.py` —
`register_all()` called at startup
- `backend/integrations/managed_providers/agentmail.py` —
`AgentMailManagedProvider` with `provision()` (creates pod + API key via
agentmail SDK) and `deprovision()` (deletes pod)

**Modified files:**
- `credentials_store.py` — `autogpt_managed` guards on update/delete,
`has_managed_credential()` / `add_managed_credential()` helpers
- `model.py` — `autogpt_managed: bool` + `metadata: dict` on
`_BaseCredentials`
- `router.py` — calls `ensure_managed_credentials()` in list endpoints,
removed explicit `/agentmail/connect` endpoint
- `user.py` — `cleanup_user_managed_credentials()` for account deletion
- `rest_api.py` — registers managed providers at startup
- `settings.py` — `agentmail_api_key` setting

### Frontend
- Added `autogpt_managed` to `CredentialsMetaResponse` type
- Conditionally hides delete button on integrations page for managed
credentials

### Key design decisions
- **Auto-provision in API layer, not data layer** — keeps
`get_all_creds()` side-effect-free
- **Race-safe** — per-(user, provider) asyncio lock with double-check
pattern prevents duplicate pods
- **Idempotent** — AgentMail SDK `client_id` ensures pod creation is
idempotent; `add_managed_credential()` uses upsert under Redis lock
- **Error-resilient** — provisioning failures are logged but never block
credential listing

### Changes 🏗️

| File | Action | Description |
|------|--------|-------------|
| `backend/integrations/managed_credentials.py` | NEW | ABC, registry,
ensure/cleanup |
| `backend/integrations/managed_providers/__init__.py` | NEW | Registers
all providers at startup |
| `backend/integrations/managed_providers/agentmail.py` | NEW |
AgentMail provisioning/deprovisioning |
| `backend/integrations/credentials_store.py` | MODIFY | Guards +
managed credential helpers |
| `backend/data/model.py` | MODIFY | `autogpt_managed` + `metadata`
fields |
| `backend/api/features/integrations/router.py` | MODIFY |
Auto-provision on list, removed `/agentmail/connect` |
| `backend/data/user.py` | MODIFY | Account deletion cleanup |
| `backend/api/rest_api.py` | MODIFY | Provider registration at startup
|
| `backend/util/settings.py` | MODIFY | `agentmail_api_key` setting |
| `frontend/.../integrations/page.tsx` | MODIFY | Hide delete for
managed creds |
| `frontend/.../types.ts` | MODIFY | `autogpt_managed` field |

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] 23 tests pass in `router_test.py` (9 new tests for
ensure/cleanup/auto-provisioning)
  - [x] `poetry run format && poetry run lint` — clean
  - [x] OpenAPI schema regenerated
- [x] Manual: verify managed credential appears in AgentMail block
dropdown
  - [x] Manual: verify delete button hidden for managed credentials
- [x] Manual: verify managed credential cannot be deleted via API (403)

#### For configuration changes:
- [x] `.env.default` is updated with `AGENTMAIL_API_KEY=`

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:56:18 +00:00
Krishna Chaitanya
a20188ae59 fix(blocks): validate non-empty input in AIConversationBlock before LLM call (#12545)
### Why / What / How

**Why:** When `AIConversationBlock` receives an empty messages list and
an empty prompt, the block blindly forwards the empty array to the
downstream LLM API, which returns a cryptic `400 Bad Request` error:
`"Invalid 'messages': empty array. Expected an array with minimum length
1."` This is confusing for users who don't understand why their agent
failed.

**What:** Add early input validation in `AIConversationBlock.run()` that
raises a clear `ValueError` when both `messages` and `prompt` are empty.
Also add three unit tests covering the validation logic.

**How:** A simple guard clause at the top of the `run` method checks `if
not input_data.messages and not input_data.prompt` before the LLM call
is made. If both are empty, a descriptive `ValueError` is raised. If
either one has content, the block proceeds normally.

### Changes

- `autogpt_platform/backend/backend/blocks/llm.py`: Add validation guard
in `AIConversationBlock.run()` to reject empty messages + empty prompt
before calling the LLM
- `autogpt_platform/backend/backend/blocks/test/test_llm.py`: Add
`TestAIConversationBlockValidation` with three tests:
- `test_empty_messages_and_empty_prompt_raises_error` — validates the
guard clause
- `test_empty_messages_with_prompt_succeeds` — ensures prompt-only usage
still works
- `test_nonempty_messages_with_empty_prompt_succeeds` — ensures
messages-only usage still works

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Lint passes (`ruff check`)
  - [x] Formatting passes (`ruff format`)
- [x] New unit tests validate the empty-input guard and the happy paths

Closes #11875

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:43:42 +00:00
goingforstudying-ctrl
c410be890e fix: add empty choices guard in extract_openai_tool_calls() (#12540)
## Summary

`extract_openai_tool_calls()` in `llm.py` crashes with `IndexError` when
the LLM provider returns a response with an empty `choices` list.

### Changes 🏗️

- Added a guard check `if not response.choices: return None` before
accessing `response.choices[0]`
- This is consistent with the function's existing pattern of returning
`None` when no tool calls are found

### Bug Details

When an LLM provider returns a response with an empty choices list
(e.g., due to content filtering, rate limiting, or API errors),
`response.choices[0]` raises `IndexError`. This can crash the entire
agent execution pipeline.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- Verified that the function returns `None` when `response.choices` is
empty
- Verified existing behavior is unchanged when `response.choices` is
non-empty

---------

Co-authored-by: goingforstudying-ctrl <forgithubuse@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 20:10:27 +07:00
Zamil Majdy
37d9863552 feat(platform): add extended thinking execution mode to OrchestratorBlock (#12512)
## Summary
- Adds `ExecutionMode` enum with `BUILT_IN` (default built-in tool-call
loop) and `EXTENDED_THINKING` (delegates to Claude Agent SDK for richer
reasoning)
- Extracts shared `tool_call_loop` into `backend/util/tool_call_loop.py`
— reusable by both OrchestratorBlock agent mode and copilot baseline
- Refactors copilot baseline to use the shared `tool_call_loop` with
callback-driven iteration

## ExecutionMode enum
`ExecutionMode` (`backend/blocks/orchestrator.py`) controls how
OrchestratorBlock executes tool calls:
- **`BUILT_IN`** — Default mode. Runs the built-in tool-call loop
(supports all LLM providers).
- **`EXTENDED_THINKING`** — Delegates to the Claude Agent SDK for
extended thinking and multi-step planning. Requires Anthropic-compatible
providers (`anthropic` / `open_router`) and direct API credentials
(subscription mode not supported). Validates both provider and model
name at runtime.

## Shared tool_call_loop
`backend/util/tool_call_loop.py` provides a generic, provider-agnostic
conversation loop:
1. Call LLM with tools → 2. Extract tool calls → 3. Execute tools → 4.
Update conversation → 5. Repeat

Callers provide three callbacks:
- `llm_call`: wraps any LLM provider (OpenAI streaming, Anthropic,
llm.llm_call, etc.)
- `execute_tool`: wraps any tool execution (TOOL_REGISTRY, graph block
execution, etc.)
- `update_conversation`: formats messages for the specific protocol

## OrchestratorBlock EXTENDED_THINKING mode
- `_create_graph_mcp_server()` converts graph-connected blocks to MCP
tools
- `_execute_tools_sdk_mode()` runs `ClaudeSDKClient` with those MCP
tools
- Agent mode refactored to use shared `tool_call_loop`

## Copilot baseline refactored
- Streaming callbacks buffer `Stream*` events during loop execution
- Events are drained after `tool_call_loop` returns
- Same conversation logic, less code duplication

## SDK environment builder extraction
- `build_sdk_env()` extracted to `backend/copilot/sdk/env.py` for reuse
by both copilot SDK service and OrchestratorBlock

## Provider validation
EXTENDED_THINKING mode validates `provider in ('anthropic',
'open_router')` and `model_name.startswith('claude')` because the Claude
Agent SDK requires an Anthropic API key or OpenRouter key. Subscription
mode is not supported — it uses the platform's internal credit system
which doesn't provide raw API keys needed by the SDK. The validation
raises a clear `ValueError` if an unsupported provider or model is used.

## PR Dependencies
This PR builds on #12511 (Claude SDK client). It can be reviewed
independently — #12511 only adds the SDK client module which this PR
imports. If #12511 merges first, this PR will have no conflicts.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All pre-commit hooks pass (typecheck, lint, format)
  - [x] Existing OrchestratorBlock tests still pass
- [x] Copilot baseline behavior unchanged (same stream events, same tool
execution)
- [x] Manual: OrchestratorBlock with execution_mode=EXTENDED_THINKING +
downstream blocks → SDK calls tools
  - [x] Agent mode regression test (non-SDK path works as before)
  - [x] SDK mode error handling (invalid provider raises ValueError)
2026-03-31 20:04:13 +07:00
Krishna Chaitanya
2f42ff9b47 fix(blocks): validate email recipients in Gmail blocks before API call (#12546)
### Why / What / How

**Why:** When a user or LLM supplies a malformed recipient string (e.g.
a bare username, a JSON blob, or an empty value) to `GmailSendBlock`,
`GmailCreateDraftBlock`, or any reply block, the Gmail API returns an
opaque `HttpError 400: "Invalid To header"`. This surfaces as a
`BlockUnknownError` with no actionable guidance, making it impossible
for the LLM to self-correct. (Fixes #11954)

**What:** Adds a lightweight `validate_email_recipients()` function that
checks every recipient against a simplified RFC 5322 pattern
(`local@domain.tld`) and raises a clear `ValueError` listing all invalid
entries before any API call is made.

**How:** The validation is called in two shared code paths —
`create_mime_message()` (used by send and draft blocks) and
`_build_reply_message()` (used by reply blocks) — so all Gmail blocks
that compose outgoing email benefit from it with zero per-block changes.
The regex is intentionally permissive (any `x@y.z` passes) to avoid
false positives on unusual but valid addresses.

### Changes 🏗️

- Added `validate_email_recipients()` helper in `gmail.py` with a
compiled regex
- Hooked validation into `create_mime_message()` for `to`, `cc`, and
`bcc` fields
- Hooked validation into `_build_reply_message()` for reply/draft-reply
blocks
- Added `TestValidateEmailRecipients` test class covering valid,
invalid, mixed, empty, JSON-string, and field-name scenarios

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified `validate_email_recipients` correctly accepts valid
emails (`user@example.com`, `a@b.com`, `test@sub.domain.co`)
- [x] Verified it rejects malformed entries (bare names, missing domain
dot, empty strings, JSON strings)
- [x] Verified error messages include the field name and all invalid
entries
  - [x] Verified empty recipient lists pass without error
  - [x] Confirmed `gmail.py` and test file parse correctly (AST check)

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:37:33 +00:00
Zamil Majdy
914efc53e5 fix(backend): disambiguate duplicate tool names in OrchestratorBlock (#12555)
## Why
The OrchestratorBlock fails with `Tool names must be unique` when
multiple nodes use the same block type (e.g., two "Web Search" blocks
connected as tools). The Anthropic API rejects the request because
duplicate tool names are sent.

## What
- Detect duplicate tool names after building tool signatures
- Append `_1`, `_2`, etc. suffixes to disambiguate
- Enrich descriptions of duplicate tools with their hardcoded default
values so the LLM can distinguish between them
- Clean up internal `_hardcoded_defaults` metadata before sending to API
- Exclude sensitive/credential fields from default value descriptions

## How
- After `_create_tool_node_signatures` builds all tool functions, count
name occurrences
- For duplicates: rename with suffix and append `[Pre-configured:
key=value]` to description using the node's `input_default` (excluding
linked fields that the LLM provides)
- Added defensive `isinstance(defaults, dict)` check for compatibility
with test mocks
- Suffix collision avoidance: skips candidates that collide with
existing tool names
- Long tool names truncated to fit within 64-character API limit
- 47 unit tests covering: basic dedup, description enrichment, unique
names unchanged, no metadata leaks, single tool, triple duplicates,
linked field exclusion, mixed unique/duplicate scenarios, sensitive
field exclusion, long name truncation, suffix collision, malformed
tools, missing description, empty list, 10-tool all-same-name, multiple
distinct groups, large default truncation, suffix collision cascade,
parameter preservation, boundary name lengths, nested dict/list
defaults, null defaults, customized name priority, required fields

## Test plan
- [x] All 47 tests in `test_orchestrator_tool_dedup.py` pass
- [x] All 11 existing orchestrator unit tests pass (dict, dynamic
fields, responses API)
- [x] Pre-commit hooks pass (ruff, black, isort, pyright)
- [ ] Manual test: connect two same-type blocks to an orchestrator and
verify the LLM call succeeds

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 11:54:10 +00:00
Carson Kahn
17e78ca382 fix(docs): remove extraneous whitespace in README (#12587)
### Why / What / How

Remove extraneous whitespace in README.md:
- "Workflow Management" description: extra spaces between "block" and
"performs"
- "Agent Interaction" description: extra spaces between "user-friendly"
and "interface"

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 08:38:45 +00:00
Ubbe
7ba05366ed feat(platform/copilot): live timer stats with persisted duration (#12583)
## Why

The copilot chat had no indication of how long the AI spent "thinking"
on a response. Users couldn't tell if a long wait was normal or
something was stuck. Additionally, the thinking duration was lost on
page reload since it was only tracked client-side.

## What

- **Live elapsed timer**: Shows elapsed time ("23s", "1m 5s") in the
ThinkingIndicator while the AI is processing (appears after 20s to avoid
spam on quick responses)
- **Frozen "Thought for Xm Ys"**: Displays the final thinking duration
in TurnStatsBar after the response completes
- **Persisted duration**: Saves `durationMs` on the last assistant
message in the DB so the timer survives page reloads

## How

**Backend:**
- Added `durationMs Int?` column to `ChatMessage` (Prisma migration)
- `mark_session_completed` in `stream_registry.py` computes wall-clock
duration from Redis session `created_at` and saves it via
`DatabaseManager.set_turn_duration()`
- Invalidates Redis session cache after writing so GET returns fresh
data

**Frontend:**
- `useElapsedTimer` hook tracks client-side elapsed seconds during
streaming
- `ThinkingIndicator` shows only the elapsed time (no phrases) after
20s, with `font-mono text-sm` styling
- `TurnStatsBar` displays "Thought for Xs" after completion, preferring
live `elapsedSeconds` and falling back to persisted `durationMs`
- `convertChatSessionToUiMessages` extracts `duration_ms` from
historical messages into a `Map<string, number>` threaded through to
`ChatMessagesContainer`

## Test plan

- [ ] Send a message in copilot — verify ThinkingIndicator shows elapsed
time after 20s
- [ ] After response completes — verify "Thought for Xs" appears below
the response
- [ ] Refresh the page — verify "Thought for Xs" still appears
(persisted from DB)
- [ ] Check older conversations — they should NOT show timer (no
historical data)
- [ ] Verify no Zod/SSE validation errors in browser console

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 16:46:31 +07:00
Zamil Majdy
ca74f980c1 fix(copilot): resolve host-scoped credentials for authenticated web requests (#12579)
## Summary
- Fixed `_resolve_discriminated_credentials()` in `helpers.py` to handle
URL/host-based credential discrimination (used by
`SendAuthenticatedWebRequestBlock`)
- Previously, only provider-based discrimination (with
`discriminator_mapping`) was handled; URL-based discrimination (with
`discriminator` set but no `discriminator_mapping`) was silently skipped
- This caused host-scoped credentials to either match the wrong host or
fail to match at all when the CoPilot called `run_block` for
authenticated HTTP requests
- Added 14 targeted tests covering discriminator resolution, host
matching, credential resolution integration, and RunBlockTool end-to-end
flows

## Root Cause
`_resolve_discriminated_credentials()` checked `if
field_info.discriminator and field_info.discriminator_mapping:` which
excluded host-scoped credentials where `discriminator="url"` but
`discriminator_mapping=None`. The URL from `input_data` was never added
to `discriminator_values`, so `_credential_is_for_host()` received empty
`discriminator_values` and returned `True` for **any** host-scoped
credential regardless of URL match.

## Fix
When `discriminator` is set without `discriminator_mapping`, the URL
value from `input_data` is now copied into `discriminator_values` on a
shallow copy of the field info (to avoid mutating the cached schema).
This enables `_credential_is_for_host()` to properly match the
credential's host against the target URL.

## Test plan
- [x] `TestResolveDiscriminatedCredentials` - 4 tests verifying URL
discriminator populates values, handles missing URL, doesn't mutate
original, preserves provider/type
- [x] `TestFindMatchingHostScopedCredential` - 5 tests verifying
correct/wrong host matching, wildcard hosts, multiple credential
selection
- [x] `TestResolveBlockCredentials` - 3 integration tests verifying full
credential resolution with matching/wrong/missing hosts
- [x] `TestRunBlockToolAuthenticatedHttp` - 2 end-to-end tests verifying
SetupRequirementsResponse when creds missing and BlockDetailsResponse
when creds matched
- [x] All 28 existing + new tests pass
- [x] Ruff lint, isort, Black formatting, pyright typecheck all pass

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-28 08:12:33 +00:00
Zamil Majdy
68f5d2ad08 fix(blocks): raise AIConditionBlock errors instead of swallowing them (#12593)
## Why

Sentry alert
[AUTOGPT-SERVER-8C8](https://significant-gravitas.sentry.io/issues/7367978095/)
— `AIConditionBlock` failing in prod with:

```
Invalid 'max_output_tokens': integer below minimum value.
Expected a value >= 16, but got 10 instead.
```

Two problems:
1. `max_tokens=10` is below OpenAI's new minimum of 16
2. The `except Exception` handler was calling `logger.error()` which
triggered Sentry for what are known block errors, AND silently
defaulting to `result=False` — making the block appear to succeed with
an incorrect answer

## What

- Bump `max_tokens` from 10 to 16 (fixes the root cause)
- Remove the `try/except` entirely — the executor already handles
exceptions correctly (`ValueError` = known/no Sentry, everything else =
unknown/Sentry). The old handler was just swallowing errors and
producing wrong results.

## Test plan

- [x] Existing `AIConditionBlock` tests pass (block only expects
"true"/"false", 16 tokens is plenty)
- [x] No more silent `result=False` on errors
- [x] No more spurious Sentry alerts from `logger.error()`

Fixes AUTOGPT-SERVER-8C8

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 10:28:14 +00:00
Nicholas Tindle
2b3d730ca9 dx(skills): add /open-pr and /setup-repo skills (#12591)
### Why / What / How

**Why:** Agents working in worktrees lack guidance on two of the most
common workflows: properly opening PRs (using the repo template,
validating test coverage, triggering the review bot) and bootstrapping
the repo from scratch with a worktree-based layout. Without these
skills, agents either skip steps (no test plan, wrong template) or
require manual hand-holding for setup.

**What:** Adds two new Claude Code skills under `.claude/skills/`:
- `/open-pr` — A structured PR creation workflow that enforces the
canonical `.github/PULL_REQUEST_TEMPLATE.md`, validates test coverage
for existing and new behaviors, supports a configurable base branch, and
integrates the `/review` bot workflow for agents without local testing
capability. Cross-references `/pr-test`, `/pr-review`, and `/pr-address`
for the full PR lifecycle.
- `/setup-repo` — An interactive repo bootstrapping skill that creates a
worktree-based layout (main + reviews + N numbered work branches).
Handles .env file provisioning with graceful fallbacks (.env.default,
.env.example), copies branchlet config, installs dependencies, and is
fully idempotent (safe to re-run).

**How:** Markdown-based SKILL.md files following the existing skill
conventions. Both skills use proper bash patterns (seq-based loops
instead of brace expansion with variables, existence checks before
branch/worktree creation, error reporting on install failures).
`/open-pr` delegates to AskUserQuestion-style prompts for base branch
selection. `/setup-repo` uses AskUserQuestion for interactive branch
count and base branch selection.

### Changes 🏗️

- Added `.claude/skills/open-pr/SKILL.md` — PR creation workflow with:
  - Pre-flight checks (committed, pushed, formatted)
- Test coverage validation (existing behavior not broken, new behavior
covered)
- Canonical PR template enforcement (read and fill verbatim, no
pre-checked boxes)
  - Configurable base branch (defaults to dev)
- Review bot workflow (`/review` comment + 30min wait) for agents
without local testing
  - Related skills table linking `/pr-test`, `/pr-review`, `/pr-address`

- Added `.claude/skills/setup-repo/SKILL.md` — Repo bootstrap workflow
with:
- Interactive setup (branch count: 4/8/16/custom, base branch selection)
- Idempotent branch creation (skips existing branches with info message)
  - Idempotent worktree creation (skips existing directories)
- .env provisioning with fallback chain (.env → .env.default →
.env.example → warning)
  - Branchlet config propagation
  - Dependency installation with success/failure reporting per worktree

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified SKILL.md frontmatter follows existing skill conventions
  - [x] Verified trigger conditions match expected user intents
  - [x] Verified cross-references to existing skills are accurate
- [x] Verified PR template section matches
`.github/PULL_REQUEST_TEMPLATE.md`
- [x] Verified bash snippets use correct patterns (seq, show-ref, quoted
vars)
  - [x] Pre-commit hooks pass on all commits
  - [x] Addressed all CodeRabbit, Sentry, and Cursor review comments

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Low risk documentation-only change: adds new markdown skills without
modifying runtime code. Main risk is workflow guidance drift (e.g.,
`.env`/worktree steps) if it diverges from actual repo conventions.
> 
> **Overview**
> Adds two new Claude Code skills under `.claude/skills/` to standardize
common developer workflows.
> 
> `/open-pr` documents a PR creation flow that enforces using
`.github/PULL_REQUEST_TEMPLATE.md` verbatim, calls out required test
coverage, and describes how to trigger/poll the `/review` bot when local
testing isn’t available.
> 
> `/setup-repo` documents an idempotent, interactive bootstrap for a
multi-worktree layout (creates `reviews` and `branch1..N`, provisions
`.env` files with `.env.default`/`.env.example` fallbacks, copies
`.branchlet.json`, and installs dependencies), complementing the
existing `/worktree` skill.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
80dbeb1596. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-03-27 10:22:03 +00:00
Zamil Majdy
f28628e34b fix(backend): preserve thinking blocks during transcript compaction (#12574)
## Why

AutoPilot users hit `invalid_request_error` ("thinking or
redacted_thinking blocks in the latest assistant message cannot be
modified") when sessions get long enough to trigger transcript
compaction. The Anthropic API requires thinking blocks in the last
assistant message to be byte-for-byte identical to the original response
— our compaction was flattening them to plain text, destroying the
cryptographic signatures.

Reported in Discord `#breakage` by John Ababseh with session
`31d3f08a-cb94-45eb-9fce-56b3f0287ef4`.

## What

- **`compact_transcript`** now splits the transcript into a compressible
prefix and a preserved tail (last assistant entry + trailing entries).
Only the prefix is compressed; the tail is re-appended verbatim,
preserving thinking blocks exactly.
- **`_flatten_assistant_content`** now silently drops `thinking` and
`redacted_thinking` blocks instead of creating `[__thinking__]`
placeholders — they carry no useful context for compression summaries.
- **`response_adapter`** explicitly handles `ThinkingBlock` (skip
gracefully instead of silently falling through the isinstance chain).
- **`_format_sdk_content_blocks`** now passes through raw dict blocks
(e.g. `redacted_thinking` that the SDK may not have a typed class for)
verbatim to the transcript.

## How

The key insight is the Anthropic API's asymmetric constraint:
- **Last assistant message**: thinking/redacted_thinking blocks must be
preserved byte-for-byte
- **Older assistant messages**: thinking blocks can be removed entirely

`compact_transcript` uses `_find_last_assistant_entry()` to split the
JSONL into two parts:
1. **Prefix** (everything before the last assistant): flattened and
compressed normally
2. **Tail** (last assistant + any trailing user message): preserved
verbatim and re-chained via `_rechain_tail()` to maintain the
`parentUuid` chain

This ensures the API always sees the original thinking blocks in the
last assistant message while still achieving meaningful compression on
older turns.

## Test plan
- [x] 25 new tests across `thinking_blocks_test.py` (TDD: written before
implementation)
- [x] `_find_last_assistant_entry` splits correctly at last assistant,
handles edges (no assistant, index 0, trailing user)
  - [x] `_rechain_tail` patches parentUuid chain, handles empty tail
- [x] `_flatten_assistant_content` strips thinking/redacted_thinking
blocks, handles mixed content
  - [x] `compact_transcript` preserves last assistant's thinking blocks
- [x] `compact_transcript` strips thinking from older assistant messages
- [x] Edge cases: trailing user message, single assistant, no thinking
blocks
  - [x] `response_adapter` handles ThinkingBlock without crash
- [x] `_format_sdk_content_blocks` preserves thinking block format and
raw dict blocks
- [x] All existing copilot SDK tests pass
- [x] Pre-commit hooks (lint, format, typecheck) all pass
2026-03-27 06:36:52 +00:00
Zamil Majdy
b6a027fd2b fix(platform): fix prod Sentry errors and reduce on-call alert noise (#12565)
## Why

Multiple Sentry issues paging on-call in prod:

1. **AUTOGPT-SERVER-8BP**: `ConversionError: Failed to convert
anthropic/claude-sonnet-4-6 to <enum 'LlmModel'>` — the copilot passes
OpenRouter-style provider-prefixed model names
(`anthropic/claude-sonnet-4-6`) to blocks, but the `LlmModel` enum only
recognizes the bare model ID (`claude-sonnet-4-6`).

2. **BUILDER-7GF**: `Error invoking postEvent: Method not found` —
Sentry SDK internal error on Chrome Mobile Android, not a platform bug.

3. **XMLParserBlock**: `BlockUnknownError raised by XMLParserBlock with
message: Error in input xml syntax` — user sent bad XML but the block
raised `SyntaxError`, which gets wrapped as `BlockUnknownError`
(unexpected) instead of `BlockExecutionError` (expected).

4. **AUTOGPT-SERVER-8BS**: `Virus scanning failed for Screenshot
2026-03-26 091900.png: range() arg 3 must not be zero` — empty (0-byte)
file upload causes `range(0, 0, 0)` in the virus scanner chunking loop,
and the failure is logged at `error` level which pages on-call.

5. **AUTOGPT-SERVER-8BT**: `ValueError: <Token var=<ContextVar
name='current_context'>> was created in a different Context` —
OpenTelemetry `context.detach()` fails when the SDK streaming async
generator is garbage-collected in a different context than where it was
created (client disconnect mid-stream).

6. **AUTOGPT-SERVER-8BW**: `RuntimeError: Attempted to exit cancel scope
in a different task than it was entered in` — anyio's
`TaskGroup.__aexit__` detects cancel scope entered in one task but
exited in another when `GeneratorExit` interrupts the SDK cleanup during
client disconnect.

7. **Workspace UniqueViolationError**: `UniqueViolationError: Unique
constraint failed on (workspaceId, path)` — race condition during
concurrent file uploads handled by `WorkspaceManager._persist_db_record`
retry logic, but Sentry still captures the exception at the raise site.

8. **Library UniqueViolationError**: `UniqueViolationError` on
`LibraryAgent (userId, agentGraphId, agentGraphVersion)` — race
conditions in `add_graph_to_library` and `create_library_agent` caused
crashes or silent data loss.

9. **Graph version collision**: `UniqueViolationError` on `AgentGraph
(id, version)` — copilot re-saving an agent at an existing version
collides with the primary key.

## What

### Backend: `LlmModel._missing_()` for provider-prefixed model names
- Adds `_missing_` classmethod to `LlmModel` enum that strips the
provider prefix (e.g., `anthropic/`) when direct lookup fails
- Self-contained in the enum — no changes to the generic type conversion
system

### Frontend: Filter Sentry SDK noise
- Adds `postEvent: Method not found` to `ignoreErrors` — a known Sentry
SDK issue on certain mobile browsers

### Backend: XMLParserBlock — raise ValueError instead of SyntaxError
- Changed `_validate_tokens()` to raise `ValueError` instead of
`SyntaxError`
- Changed the `except SyntaxError` handler in `run()` to re-raise as
`ValueError`
- This ensures `Block.execute()` wraps XML parsing failures as
`BlockExecutionError` (expected/user-caused) instead of
`BlockUnknownError` (unexpected/alerts Sentry)

### Backend: Virus scanner — handle empty files + reduce alert noise
- Added early return for empty (0-byte) files in `scan_file()` to avoid
`range() arg 3 must not be zero` when `chunk_size` is 0
- Added `max(1, len(content))` guard on `chunk_size` as defense-in-depth
- Downgraded `scan_content_safe` failure log from `error` to `warning`
so single-file scan failures don't page on-call via Sentry

### Backend: Suppress SDK client cleanup errors on SSE disconnect
- Replaced `async with ClaudeSDKClient` in `_run_stream_attempt` with
manual `__aenter__`/`__aexit__` wrapped in new
`_safe_close_sdk_client()` helper
- `_safe_close_sdk_client()` catches `ValueError` (OTEL context token
mismatch) and `RuntimeError` (anyio cancel scope in wrong task) during
`__aexit__` and logs at `debug` level — these are expected when SSE
client disconnects mid-stream
- Added `_is_sdk_disconnect_error()` helper for defense-in-depth at the
outer `except BaseException` handler in `stream_chat_completion_sdk`
- Both Sentry errors (8BT and 8BW) are now suppressed without affecting
normal cleanup flow

### Backend: Filter workspace UniqueViolationError from Sentry alerts
- Added `before_send` filter in `_before_send()` to drop
`UniqueViolationError` events where the message contains `workspaceId`
and `path`
- The error is already handled by `WorkspaceManager._persist_db_record`
retry logic — it must propagate for the retry logic to work, so the fix
is at the Sentry filter level rather than catching/suppressing at source

### Backend: Library agent race condition fixes
- **`add_graph_to_library`**: Replaced check-then-create pattern with
create-then-catch-`UniqueViolationError`-then-update. On collision,
updates the existing row (restoring soft-deleted/archived agents)
instead of crashing.
- **`create_library_agent`**: Replaced `create` with `upsert` on the
`(userId, agentGraphId, agentGraphVersion)` composite unique constraint,
so concurrent adds restore soft-deleted entries instead of throwing.

### Backend: Graph version auto-increment on collision
- `__create_graph` now checks if the `(id, version)` already exists
before `create_many`, and auto-increments the version to `max_existing +
1` to avoid `UniqueViolationError` when the copilot re-saves an agent.

### Backend: Workspace `get_or_create_workspace` upsert
- Changed from find-then-create to `upsert` to atomically handle
concurrent workspace creation.

## Test plan

- [x] `LlmModel("anthropic/claude-sonnet-4-6")` resolves correctly
- [x] `LlmModel("claude-sonnet-4-6")` still works (no regression)
- [x] `LlmModel("invalid/nonexistent-model")` still raises `ValueError`
- [x] XMLParserBlock: unclosed tags, extra closing tags, empty XML all
raise `ValueError`
- [x] XMLParserBlock: `SyntaxError` from gravitasml library is caught
and re-raised as `ValueError`
- [x] Virus scanner: empty file (0 bytes) returns clean without hitting
ClamAV
- [x] Virus scanner: single-byte file scans normally (regression test)
- [x] Virus scanner: `scan_content_safe` logs at WARNING not ERROR on
failure
- [x] SDK disconnect: `_is_sdk_disconnect_error` correctly identifies
cancel scope and context var errors
- [x] SDK disconnect: `_is_sdk_disconnect_error` rejects unrelated
errors
- [x] SDK disconnect: `_safe_close_sdk_client` suppresses ValueError,
RuntimeError, and unexpected exceptions
- [x] SDK disconnect: `_safe_close_sdk_client` calls `__aexit__` on
clean exit
- [x] Library: `add_graph_to_library` creates new agent on first call
- [x] Library: `add_graph_to_library` updates existing on
UniqueViolationError
- [x] Library: `create_library_agent` uses upsert to handle concurrent
adds
- [x] All existing workspace overwrite tests still pass
- [x] All tests passing (existing + 4 XML syntax + 3 virus scanner + 10
SDK disconnect + library tests)
2026-03-27 06:09:42 +00:00
Zamil Majdy
fb74fcf4a4 feat(platform): add shared admin user search + rate-limit modal on spending page (#12577)
## Why
Admin rate-limit management required manually entering user UUIDs. The
spending page already had user search but it wasn't reusable.

## What
- Extract `AdminUserSearch` as shared component from spending page
search
- Add rate-limit modal (usage bars + reset) to spending page user rows
- Add email/name/UUID search to standalone rate-limits page
- Backend: add email query parameter to rate-limit endpoint

## How
- `AdminUserSearch` in `admin/components/` — reused by both spending and
rate-limits
- `RateLimitModal` opens from spending page "Rate Limits" button
- Backend `_resolve_user_id()` accepts email or user_id
- Smart routing: exact email → direct lookup, UUID → direct, partial →
fuzzy search

### Follow-up
- `AdminUserSearch` is a plain text input with no typeahead/fuzzy
suggestions — consider adding autocomplete dropdown with debounced
search

### Checklist 📋
- [x] Shared search component extracted and reused
- [x] Tests pass
- [x] Type-checked
2026-03-27 05:53:04 +00:00
Zamil Majdy
28b26dde94 feat(platform): spend credits to reset CoPilot daily rate limit (#12526)
## Summary
- When users hit their daily CoPilot token limit, they can now spend
credits ($2.00 default) to reset it and continue working
- Adds a dialog prompt when rate limit error occurs, offering the
credit-based reset option
- Adds a "Reset daily limit" button in the usage limits panel when the
daily limit is reached
- Backend: new `POST /api/chat/usage/reset` endpoint,
`reset_daily_usage()` Redis helper, `rate_limit_reset_cost` config
- Frontend: `RateLimitResetDialog` component, updated
`UsagePanelContent` with reset button, `useCopilotStream` exposes rate
limit state
- **NEW: Resetting the daily limit also reduces weekly usage by the
daily limit amount**, effectively granting 1 extra day's worth of weekly
capacity (e.g., daily_limit=10000 → weekly usage reduced by 10000,
clamped to 0)

## Context
Users have been confused about having credits available but being
blocked by rate limits (REQ-63, REQ-61). This provides a short-term
solution allowing users to spend credits to bypass their daily limit.

The weekly usage reduction ensures that a paid daily reset doesn't just
move the bottleneck to the weekly limit — users get genuine additional
capacity for the day they paid to unlock.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Hit daily rate limit → dialog appears with reset option
- [x] Click "Reset for $2.00" → credits charged, daily counter reset,
dialog closes
- [x] Usage panel shows "Reset daily limit" button when at 100% daily
usage
- [x] When `rate_limit_reset_cost=0` (disabled), rate limit shows toast
instead of dialog
  - [x] Insufficient credits → error toast shown
  - [x] Verify existing rate limit tests pass
  - [x] Unit tests: weekly counter reduced by daily_limit on reset
  - [x] Unit tests: weekly counter clamped to 0 when usage < daily_limit
  - [x] Unit tests: no weekly reduction when daily_token_limit=0

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
(new config fields `rate_limit_reset_cost` and `max_daily_resets` have
defaults in code)
- [x] `docker-compose.yml` is updated or already compatible with my
changes (no Docker changes needed)
2026-03-26 13:52:08 +00:00
Zamil Majdy
d677978c90 feat(platform): admin rate limit check and reset with LD-configurable global limits (#12566)
## Why
Admins need visibility into per-user CoPilot rate limit usage and the
ability to reset a user's counters when needed (e.g., after a false
positive or for debugging). Additionally, the global rate limits were
hardcoded deploy-time constants with no way to adjust without
redeploying.

## What
- Admin endpoints to **check** a user's current rate limit usage and
**reset** their daily/weekly counters to zero
- Global rate limits are now **LaunchDarkly-configurable** via
`copilot-daily-token-limit` and `copilot-weekly-token-limit` flags,
falling back to existing `ChatConfig` values
- Frontend admin page at `/admin/rate-limits` with user lookup, usage
visualization, and reset capability
- Chat routes updated to source global limits from LD flags

## How
- **Backend**: Added `reset_user_usage()` to `rate_limit.py` that
deletes Redis usage keys. New admin routes in
`rate_limit_admin_routes.py` (GET `/api/copilot/admin/rate_limit` and
POST `/api/copilot/admin/rate_limit/reset`). Added
`COPILOT_DAILY_TOKEN_LIMIT` and `COPILOT_WEEKLY_TOKEN_LIMIT` to the
`Flag` enum. Chat routes use `_get_global_rate_limits()` helper that
checks LD first.
- **Frontend**: New `/admin/rate-limits` page with `RateLimitManager`
(user lookup) and `RateLimitDisplay` (usage bars + reset button). Added
`getUserRateLimit` and `resetUserRateLimit` to `BackendAPI` client.

## Test plan
- [x] Backend: 4 tests covering get, reset, redis failure, and
admin-only access
- [ ] Manual: Look up a user's rate limits in the admin UI
- [ ] Manual: Reset a user's usage counters
- [ ] Manual: Verify LD flag overrides are respected for global limits
2026-03-26 08:29:40 +00:00
Otto
a347c274b7 fix(frontend): replace unrealistic CoPilot suggestion prompt (#12564)
Replaces "Sort my bookmarks into categories" with "Summarize my unread
emails" in the Organize suggestion category. CoPilot has no access to
browser bookmarks or local files, so the original prompt was misleading.

---
Co-authored-by: Toran Bruce Richards (@Torantulino)
<Torantulino@users.noreply.github.com>
2026-03-26 08:10:28 +00:00
Zamil Majdy
f79d8f0449 fix(backend): move placeholder_values exclusively to AgentDropdownInputBlock (#12551)
## Why

`AgentInputBlock` has a `placeholder_values` field whose
`generate_schema()` converts it into a JSON schema `enum`. The frontend
renders any field with `enum` as a dropdown/select. This means
AI-generated agents that populate `placeholder_values` with example
values (e.g. URLs) on regular `AgentInputBlock` nodes end up with
dropdowns instead of free-text inputs — users can't type custom values.

Only `AgentDropdownInputBlock` should produce dropdown behavior.

## What

- Removed `placeholder_values` field from `AgentInputBlock.Input`
- Moved the `enum` generation logic to
`AgentDropdownInputBlock.Input.generate_schema()`
- Cleaned up test data for non-dropdown input blocks
- Updated copilot agent generation guide to stop suggesting
`placeholder_values` for `AgentInputBlock`

## How

The base `AgentInputBlock.Input.generate_schema()` no longer converts
`placeholder_values` → `enum`. Only `AgentDropdownInputBlock.Input`
defines `placeholder_values` and overrides `generate_schema()` to
produce the `enum`.

**Backward compatibility**: Existing agents with `placeholder_values` on
`AgentInputBlock` nodes load fine — `model_construct()` silently ignores
extra fields not defined on the model. Those inputs will now render as
text fields (desired behavior).

## Test plan
- [x] `poetry run pytest backend/blocks/test/test_block.py -xvs` — all
block tests pass
- [x] `poetry run format && poetry run lint` — clean
- [ ] Import an agent JSON with `placeholder_values` on an
`AgentInputBlock` — verify it loads and renders as text input
- [ ] Create an agent with `AgentDropdownInputBlock` — verify dropdown
still works
2026-03-26 08:09:38 +00:00
Otto
1bc48c55d5 feat(copilot): add copy button to user prompt messages [SECRT-2172] (#12571)
Requested by @itsababseh

Users can copy assistant output messages but not their own prompts. This
adds the same copy button to user messages — appears on hover,
right-aligned, using the existing `CopyButton` component.

## Why

Users write long prompts and need to copy them to reuse or share.
Currently requires manual text selection. ChatGPT shows copy on hover
for user messages — this matches that pattern.

## What

- Added `CopyButton` to user prompt messages in
`ChatMessagesContainer.tsx`
- Shows on hover (`group-hover:opacity-100`), positioned right-aligned
below the message
- Reuses the existing `CopyButton` and `MessageActions` components —
zero new code

## How

One file changed, 11 lines added:
1. Import `MessageActions` and `CopyButton`
2. Render them after user `MessageContent`, gated on `message.role ===
"user"` and having text parts

---
Co-authored-by: itsababseh (@itsababseh)
<36419647+itsababseh@users.noreply.github.com>
2026-03-26 08:02:28 +00:00
Abhimanyu Yadav
9d0a31c0f1 fix(frontend/builder): fix array field item layout and add FormRenderer stories (#12532)
Fix broken UI when selecting nodes with array fields (list[str],
list[Enum]) in the builder. The select/input inside array items was
squeezed by the Remove button instead of taking full width.
<img width="2559" height="1077" alt="Screenshot 2026-03-26 at 10 23
34 AM"
src="https://github.com/user-attachments/assets/2ffc28a2-8d6c-428c-897c-021b1575723c"
/>

### Changes 🏗️

- **ArrayFieldItemTemplate**: Changed layout from horizontal flex-row to
vertical flex-col so the input takes full width and Remove button sits
below aligned left, with tighter spacing between them
- **Storybook config**: Added `renderers/**` glob to
`.storybook/main.ts` so renderer stories are discoverable
- **FormRenderer stories**: Added comprehensive Storybook stories
covering all backend field types (string, int, float, bool, enum,
date/time, list[str], list[int], list[Enum], list[bool], nested objects,
Optional, anyOf unions, oneOf discriminated unions, multi-select, list
of objects, and a kitchen sink). Includes exact Twitter GetUserBlock
schema for realistic oneOf + multi-select testing.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified array field items render with full-width input and Remove
button below in Storybook
  - [x] Verified list[Enum] select dropdown takes full width
  - [x] Verified list[str] text input takes full width
- [x] Verified all FormRenderer stories render without errors in
Storybook
- [x] Verified multi-select and oneOf discriminated union stories match
real backend schemas
2026-03-26 06:15:30 +00:00
Abhimanyu Yadav
9b086e39c6 fix(frontend): hide placeholder text when copilot voice recording is active (#12534)
### Why / What / How

**Why:** When voice recording is active in the CoPilot chat input, the
recording UI (waveform + timer) overlays on top of the placeholder/hint
text, creating a visually broken appearance. Reported by a user via
SECRT-2163.

**What:** Hide the textarea placeholder text while voice recording is
active so it doesn't bleed through the `RecordingIndicator` overlay.

**How:** When `isRecording` is true, the placeholder is set to an empty
string. The existing `RecordingIndicator` overlay (waveform animation +
elapsed time) then displays cleanly without the hint text showing
underneath.

### Changes 🏗️

- Clear the `PromptInputTextarea` placeholder to `""` when voice
recording is active, preventing it from rendering behind the
`RecordingIndicator` overlay

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Open CoPilot chat at /copilot
- [x] Click the microphone button or press Space to start voice
recording
- [x] Verify the placeholder text ("Type your message..." / "What else
can I help with?") is hidden during recording
- [x] Verify the RecordingIndicator (waveform + timer) displays cleanly
without overlapping text
  - [x] Stop recording and verify placeholder text reappears
  - [x] Verify "Transcribing..." placeholder shows during transcription
2026-03-26 05:41:09 +00:00
Zamil Majdy
5867e4d613 Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-26 07:30:56 +07:00
Zamil Majdy
85f0d8353a fix(platform): fix prod Sentry errors and reduce on-call alert noise (#12560)
## Summary
Hotfix targeting master for production Sentry errors that are triggering
on-call pages. Fixes actual bugs and expands Sentry filters to suppress
user-caused errors that are not platform issues.

### Bug Fixes
- **Workspace race condition** (`get_or_create_workspace`): Replaced
Prisma's non-atomic `upsert` with find-then-create pattern. Prisma's
upsert translates to SELECT + INSERT (not PostgreSQL's native `INSERT
... ON CONFLICT`), causing `UniqueViolationError` when concurrent
requests hit for the same user (e.g. copilot + file upload
simultaneously).
- **ChatSidebar crash**: Added null-safe `?.` for `sessions` which can
be `undefined` during error/loading states, preventing `TypeError:
Cannot read properties of undefined (reading 'length')`.
- **UsageLimits crash**: Added null-safe `?.` for
`usage.daily`/`usage.weekly` which can be `undefined` when the API
returns partial data, preventing `TypeError: Cannot read properties of
undefined (reading 'limit')`.

### Sentry Filter Improvements
Expanded backend `_before_send` to stop user-caused errors from reaching
Sentry and triggering on-call alerts:
- **Consolidated auth keywords** into a shared `_USER_AUTH_KEYWORDS`
list used by both exception-based and log-based filters (previously
duplicated).
- **Added missing auth keywords**: `"unauthorized"`, `"bad
credentials"`, `"insufficient authentication scopes"` — these were
leaking through.
- **Added user integration HTTP error filter**: `"http 401 error"`,
`"http 403 error"`, `"http 404 error"` — catches `BlockUnknownError` and
`HTTPClientError` from user integrations (expired GitHub tokens, wrong
Airtable IDs, etc.).
- **Fixed log-based event gap**: User auth errors logged via
`logger.error()` (not raised as exceptions) were bypassing the
`exc_info` filter. Now the same `_USER_AUTH_KEYWORDS` list is checked
against log messages too.

## On-Call Alerts Addressed

### Fixed (actual bugs)
| Alert | Issue | Root Cause |
|-------|-------|------------|
| `Unique constraint failed on the fields: (userId)` |
[AUTOGPT-SERVER-8BM](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8BM)
| Prisma upsert race condition |
| `Unique constraint failed on the fields: (userId)` |
[AUTOGPT-SERVER-8BK](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8BK)
| Same — via `/api/workspace/files/upload` |
| `Unique constraint failed on the fields: (userId)` |
[AUTOGPT-SERVER-8BN](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8BN)
| Same — via `tools/call run_block` |
| `Upload failed (500): Unique constraint failed` |
[BUILDER-7GA](https://significant-gravitas.sentry.io/issues/BUILDER-7GA)
| Frontend surface of same workspace bug |
| `Cannot read properties of undefined (reading 'length')` |
[BUILDER-7GD](https://significant-gravitas.sentry.io/issues/BUILDER-7GD)
| `sessions` undefined in ChatSidebar |
| `Cannot read properties of undefined (reading 'limit')` |
[BUILDER-7GB](https://significant-gravitas.sentry.io/issues/BUILDER-7GB)
| `usage.daily` undefined in UsageLimits |

### Filtered (user-caused, not platform bugs)
| Alert | Issue | Why it's not a platform bug |
|-------|-------|-----------------------------|
| `Anthropic API error: invalid x-api-key` |
[AUTOGPT-SERVER-8B6](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8B6),
8B7, 8B8 | User provided invalid Anthropic API key |
| `AI condition evaluation failed: Incorrect API key` |
[AUTOGPT-SERVER-83Y](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-83Y)
| User's OpenAI key is wrong (4.5K events, 1 user) |
| `GithubListIssuesBlock: HTTP 401 Bad credentials` |
[AUTOGPT-SERVER-8BF](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8BF)
| User's GitHub token expired |
| `HTTPClientError: HTTP 401 Unauthorized` |
[AUTOGPT-SERVER-8BG](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8BG)
| Same — credential check endpoint |
| `GithubReadIssueBlock: HTTP 401 Bad credentials` |
[AUTOGPT-SERVER-8BH](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8BH)
| Same — different block |
| `AirtableCreateBaseBlock: HTTP 404 MODEL_ID_NOT_FOUND` |
[AUTOGPT-SERVER-8BC](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-8BC)
| User's Airtable model ID is wrong |

### Not addressed in this PR
| Alert | Issue | Reason |
|-------|-------|--------|
| `Unexpected token '<', "<html><hea"...` |
[BUILDER-7GC](https://significant-gravitas.sentry.io/issues/BUILDER-7GC)
| Transient — backend briefly returned HTML error page |
| `undefined is not an object (activeResponse.state)` |
[BUILDER-71J](https://significant-gravitas.sentry.io/issues/BUILDER-71J)
| Bug in Vercel AI SDK `ai@6.0.59`, already resolved |
| `Last Tool Output is needed` |
[AUTOGPT-SERVER-72T](https://significant-gravitas.sentry.io/issues/AUTOGPT-SERVER-72T)
| User graph misconfiguration (1 user, 21 events) |
| `Cannot set property ethereum` |
[BUILDER-7G6](https://significant-gravitas.sentry.io/issues/BUILDER-7G6)
| Browser wallet extension conflict |
| `File already exists at path` |
[BUILDER-7FS](https://significant-gravitas.sentry.io/issues/BUILDER-7FS)
| Expected 409 conflict |

## Test plan
- [ ] Verify workspace creation works for new users
- [ ] Verify concurrent workspace access (e.g. copilot + file upload)
doesn't error
- [ ] Verify copilot ChatSidebar and UsageLimits load correctly when API
returns partial/error data
- [ ] Verify user auth errors (invalid API keys, expired tokens) no
longer appear in Sentry after deployment
2026-03-25 23:25:32 +07:00
An Vy Le
f871717f68 fix(backend): add sink input validation to AgentValidator (#12514)
## Summary

- Added `validate_sink_input_existence` method to `AgentValidator` to
ensure all sink names in links and input defaults reference valid input
schema fields in the corresponding block
- Added comprehensive tests covering valid/invalid sink names, nested
inputs, and default key handling
- Updated `ReadDiscordMessagesBlock` description to clarify it reads new
messages and triggers on new posts
- Removed leftover test function file

## Test plan

- [ ] Run `pytest` on `validator_test.py` to verify all sink input
validation cases pass
- [ ] Verify existing agent validation flow is unaffected
- [ ] Confirm `ReadDiscordMessagesBlock` description update is accurate

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-25 16:08:17 +00:00
Ubbe
f08e52dc86 fix(frontend): marketplace card description 3 lines + fallback color (#12557)
## Summary
- Increase the marketplace StoreCard description from 2 lines to 3 lines
for better readability
- Change fallback background colour for missing agent images from
`bg-violet-50` to `rgb(216, 208, 255)`

<img width="933" height="458" alt="Screenshot 2026-03-25 at 20 25 41"
src="https://github.com/user-attachments/assets/ea433741-1397-4585-b64c-c7c3b8109584"
/>
<img width="350" height="457" alt="Screenshot 2026-03-25 at 20 25 55"
src="https://github.com/user-attachments/assets/e2029c09-518a-4404-aa95-e202b4064d0b"
/>


## Test plan
- [x] Verified `pnpm format`, `pnpm lint`, `pnpm types` all pass
- [x] Visually confirmed description shows 3 lines on marketplace cards
- [x] Visually confirmed fallback color renders correctly for cards
without images

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:58:45 +08:00
Ubbe
500b345b3b fix(frontend): auto-reconnect copilot chat after device sleep/wake (#12519)
## Summary

- Adds `visibilitychange`-based sleep/wake detection to the copilot chat
— when the page becomes visible after >30s hidden, automatically refetch
the session and either resume an active stream or hydrate completed
messages
- Blocks chat input during re-sync (`isSyncing` state) to prevent users
from accidentally sending a message that overwrites the agent's
completed work
- Replaces `PulseLoader` with a spinning `CircleNotch` icon on sidebar
session names for background streaming sessions (closer to ChatGPT's UX)

## How it works

1. When the page goes hidden, we record a timestamp
2. When the page becomes visible, we check elapsed time
3. If >30s elapsed (indicating sleep or long background), we refetch the
session from the API
4. If backend still has `active_stream=true` → remove stale assistant
message and resume SSE
5. If backend is done → the refetch triggers React Query invalidation
which hydrates the completed messages
6. Chat input stays disabled (`isSyncing=true`) until re-sync completes

## Test plan

- [ ] Open copilot, start a long-running agent task
- [ ] Close laptop lid / lock screen for >30 seconds
- [ ] Wake device — verify chat shows the agent's completed response (or
resumes streaming)
- [ ] Verify chat input is temporarily disabled during re-sync, then
re-enables
- [ ] Verify sidebar shows spinning icon (not pulse loader) for
background sessions
- [ ] Verify no duplicate messages appear after wake
- [ ] Verify normal streaming (no sleep) still works as expected

Resolves: [SECRT-2159](https://linear.app/autogpt/issue/SECRT-2159)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:15:33 +08:00
Ubbe
995dd1b5f3 feat(platform): replace suggestion pills with themed prompt categories (#12515)
## Summary

<img width="700" height="575" alt="Screenshot 2026-03-23 at 21 40 07"
src="https://github.com/user-attachments/assets/f6138c63-dd5e-4bde-a2e4-7434d0d3ec72"
/>

Re-applies #12452 which was reverted as collateral in #12485 (invite
system revert).

Replaces the flat list of suggestion pills in the CoPilot empty session
with themed prompt categories (Learn, Create, Automate, Organize), each
shown as a popover with contextual prompts.

- **Backend**: Adds `suggested_prompts` as a themed `dict[str,
list[str]]` keyed by category. Updates Tally extraction LLM prompt to
generate prompts per theme, and the `/suggested-prompts` API to return
grouped themes. Legacy `list[str]` rows are preserved under a
`"General"` key for backward compatibility.
- **Frontend**: Replaces inline pill buttons with a `SuggestionThemes`
popover component. Each theme button (with icon) opens a dropdown of 5
relevant prompts. Falls back to hardcoded defaults when the API has no
personalized prompts. Normalizes partial API responses by padding
missing themes with defaults. Legacy `"General"` prompts are distributed
round-robin across themes.

### Changes 🏗️

- `backend/data/understanding.py`: `suggested_prompts` field added as
`dict[str, list[str]]`; legacy list rows preserved under `"General"` key
via `_json_to_themed_prompts`
- `backend/data/tally.py`: LLM prompt updated to generate themed
prompts; validation now per-theme with blank-string rejection
- `backend/api/features/chat/routes.py`: New `SuggestedTheme` model;
endpoint returns `themes[]`
- `frontend/copilot/components/EmptySession/EmptySession.tsx`: Uses
generated API hooks for suggested prompts
- `frontend/copilot/components/EmptySession/helpers.ts`:
`DEFAULT_THEMES` replaces `DEFAULT_QUICK_ACTIONS`; `getSuggestionThemes`
normalizes partial API responses
-
`frontend/copilot/components/EmptySession/components/SuggestionThemes/`:
New popover component with theme icons and loading states

### Checklist 📋

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify themed suggestion buttons render on CoPilot empty session
  - [x] Click each theme button and confirm popover opens with prompts
  - [x] Click a prompt and confirm it sends the message
- [x] Verify fallback to default themes when API returns no custom
prompts
- [x] Verify legacy users' personalized prompts are preserved and
visible

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 15:32:49 +08:00
Zamil Majdy
336114f217 fix(backend): prevent graph execution stuck + steer SDK away from bash_exec (#12548)
## Summary

Two backend fixes for CoPilot stability:

1. **Steer model away from bash_exec for SDK tool-result files** — When
the SDK returns tool results as file paths, the copilot model was
attempting to use `bash_exec` to read them instead of treating the
content directly. Added system prompt guidance to prevent this.

2. **Guard against missing 'name' in execution input_data** —
`GraphExecution.from_db()` assumed all INPUT/OUTPUT block node
executions have a `name` field in `input_data`. This crashes with
`KeyError: 'name'` when non-standard blocks (e.g., OrchestratorBlock)
produce node executions without this field. Added `"name" in
exec.input_data` guards.

## Why

- The bash_exec issue causes copilot to fail when processing SDK tool
outputs
- The KeyError crashes the `update_graph_execution_stats` endpoint,
causing graph executions to appear stuck (retries 35+ times, never
completes)

## How

- Added system prompt instruction to treat tool result file contents
directly
- Added `"name" in exec.input_data` guard in both input extraction (line
340) and output extraction (line 365) in `execution.py`

### Changes
- `backend/copilot/sdk/service.py` — system prompt guidance
- `backend/data/execution.py` — KeyError guard for missing `name` field

### Checklist 📋
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan

#### Test plan:
- [x] OrchestratorBlock graph execution no longer gets stuck
- [x] Standard Agent Input/Output blocks still work correctly
- [x] Copilot SDK tool results are processed without bash_exec
2026-03-25 13:58:24 +07:00
Nicholas Tindle
866563ad25 feat(platform): admin preview marketplace submissions before approving (#12536)
## Why

Admins reviewing marketplace submissions currently approve blindly —
they can see raw metadata in the admin table but cannot see what the
listing actually looks like (images, video, branding, layout). This
risks approving inappropriate content. With full-scale production
approaching, this is critical.

Additionally, when a creator un-publishes an agent, users who already
added it to their library lose access — breaking their workflows.
Product decided on a "you added it, you keep it" model.

## What

- **Admin preview page** at `/admin/marketplace/preview/[id]` — renders
the listing exactly as it would appear on the public marketplace
- **Add to Library** for admins to test-run pending agents before
approving
- **Library membership grants graph access** — if you added an agent to
your library, you keep access even if it's un-published or rejected
- **Preview button** on every submission row in the admin marketplace
table
- **Cross-reference comments** on original functions to prevent
SECRT-2162-style regressions

## How

### Backend

**Admin preview (`store/db.py`):**
- `get_store_agent_details_as_admin()` queries `StoreListingVersion`
directly, bypassing the APPROVED-only `StoreAgent` DB view
- Validates `CreatorProfile` FK integrity, reads all fields including
`recommendedScheduleCron`

**Admin add-to-library (`library/_add_to_library.py`):**
- Extracted shared logic into `resolve_graph_for_library()` +
`add_graph_to_library()` — eliminates duplication between public and
admin paths
- Admin path uses `get_graph_as_admin()` to bypass marketplace status
checks
- Handles concurrent double-click race via `UniqueViolationError` catch

**Library membership grants graph access (`data/graph.py`):**
- `get_graph()` now falls back to `LibraryAgent` lookup if ownership and
marketplace checks fail
- Only for authenticated users with non-deleted, non-archived library
records
- `validate_graph_execution_permissions()` updated to match — library
membership grants execution access too

**New endpoints (`store_admin_routes.py`):**
- `GET /admin/submissions/{id}/preview` — returns `StoreAgentDetails`
- `POST /admin/submissions/{id}/add-to-library` — creates `LibraryAgent`
via admin path

### Frontend

- Preview page reuses `AgentInfo` + `AgentImages` with admin banner
- Shows instructions, recommended schedule, and slug
- "Add to My Library" button wired to admin endpoint
- Preview button added to `ExpandableRow` (header + version history)
- Categories column uncommented in version history table

### Testing (19 tests)

**Graph access control (9 in `graph_test.py`):** Owner access,
marketplace access, library member access (unpublished),
deleted/archived/anonymous denied, null FK denied, efficiency checks

**Admin bypass (5 in `store_admin_routes_test.py`):** Preview uses
StoreListingVersion not StoreAgent, admin path uses get_graph_as_admin,
regular path uses get_graph, library member can view in builder

**Security (3):** Non-admin 403 on preview, non-admin 403 on
add-to-library, nonexistent 404

**SECRT-2162 regression (2):** Admin access to pending agent, export
with sub-graphs

### Checklist
- [x] Changes clearly listed
- [x] Test plan made
- [x] 19 backend tests pass
- [x] Frontend lints and types clean

## Test plan
- [x] Navigate to `/admin/marketplace`, click Preview on a PENDING
submission
- [x] Verify images, video, description, categories, instructions,
schedule render correctly
- [x] Click "Add to My Library", verify agent appears in library and
opens in builder
- [x] Verify non-admin users get 403
- [x] Verify un-publishing doesn't break access for users who already
added it

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **High Risk**
> Adds new admin-only endpoints that bypass marketplace
approval/ownership checks and changes `get_graph`/execution
authorization to grant access via library membership, which impacts
security-sensitive access control paths.
> 
> **Overview**
> Adds **admin preview + review workflow support** for marketplace
submissions: new admin routes to `GET /admin/submissions/{id}/preview`
(querying `StoreListingVersion` directly) and `POST
/admin/submissions/{id}/add-to-library` (admin bypass to pull pending
graphs into an admin’s library).
> 
> Refactors library add-from-store logic into shared helpers
(`resolve_graph_for_library`, `add_graph_to_library`) and introduces an
admin variant `add_store_agent_to_library_as_admin`, including restore
of archived/deleted entries and dedup/race handling.
> 
> Changes core graph access rules: `get_graph()` now falls back to
**library membership** (non-deleted/non-archived, version-specific) when
ownership and marketplace approval don’t apply, and
`validate_graph_execution_permissions()` is updated accordingly.
Frontend adds a preview link and a dedicated admin preview page with
“Add to My Library”; tests expand significantly to lock in the new
bypass and access-control behavior.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
a362415d12. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 04:26:36 +00:00
Zamil Majdy
e79928a815 fix(backend): prevent logging sensitive data in SafeJson fallback (#12547)
### Why / What / How

**Why:** GitHub's code scanning detected a HIGH severity security
vulnerability in `/autogpt_platform/backend/backend/util/json.py:172`.
The error handler in `sanitize_json()` was logging sensitive data
(potentially including secrets, API keys, credentials) as clear text
when serialization fails.

**What:** This PR removes the logging of actual data content from the
error handler while preserving useful debugging metadata (error type,
error message, and data type).

**How:** Removed the `"Data preview: %s"` format parameter and the
corresponding `truncate(str(data), 100)` argument from the
logger.error() call. The error handler now logs only safe metadata that
helps debugging without exposing sensitive information.

### Changes 🏗️

- **Security Fix**: Modified `sanitize_json()` function in
`backend/util/json.py`
- Removed logging of data content (`truncate(str(data), 100)`) from the
error handler
  - Retained logging of error type (`type(e).__name__`)
- Retained logging of truncated error message (`truncate(str(e), 200)`)
  - Retained logging of data type (`type(data).__name__`)
- Error handler still provides useful debugging information without
exposing secrets

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified the code passes type checking (`poetry run pyright
backend/util/json.py`)
- [x] Verified the code passes linting (`poetry run ruff check
backend/util/json.py`)
  - [x] Verified all pre-commit hooks pass
- [x] Reviewed the diff to ensure only the sensitive data logging was
removed
- [x] Confirmed that useful debugging information (error type, error
message, data type) is still logged

#### For configuration changes:
- N/A - No configuration changes required
2026-03-25 04:21:21 +00:00
Zamil Majdy
1771ed3bef dx(skills): codify PR workflow rules in skill docs and CLAUDE.md (#12531)
## Summary

- **pr-address skill**: Add explicit rule against empty commits for CI
re-triggers, and strengthen push-immediately guidance with rationale
- **Platform CLAUDE.md**: Add "split PRs by concern" guideline under
Creating Pull Requests

### Changes
- Updated `.claude/skills/pr-address/SKILL.md`
- Updated `autogpt_platform/CLAUDE.md`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan

#### Test plan:
- [x] Documentation-only changes — no functional tests needed
- [x] Verified markdown renders correctly
2026-03-25 10:19:30 +07:00
Zamil Majdy
550fa5a319 fix(backend): register AutoPilot sessions with stream registry for SSE updates (#12500)
### Changes 🏗️
- When the AutoPilot block executes a copilot session via
`collect_copilot_response`, it calls `stream_chat_completion_sdk`
directly, bypassing the copilot executor and stream registry. This means
the frontend sees no `active_stream` on the session and cannot connect
via SSE — users see a frozen chat with no updates until the turn fully
completes.
- Fix: register a `stream_registry` session in
`collect_copilot_response` and publish each chunk to Redis as events are
consumed. This allows the frontend to detect `active_stream=true` and
connect via the SSE reconnect endpoint for live streaming updates during
AutoPilot execution.
- Error handling is graceful — if stream registry fails, AutoPilot still
works normally, just without real-time frontend updates.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Trigger an AutoPilot block execution that creates a new chat
session
- [x] Verify the new session appears in the sidebar with streaming
indicator
- [x] Click on the session while AutoPilot is still executing — verify
SSE connects and messages stream in real-time
- [x] Verify that after AutoPilot completes, the session shows as
complete (no active_stream)
- [x] Test reconnection: disconnect and reconnect while AutoPilot is
running — verify stream resumes (found and fixed GeneratorExit bug that
caused stuck sessions)
- [x] E2E: 10 stream events published to Redis (StreamStart,
3×ToolInput, 3×ToolOutput, TextStart, TextEnd, StreamFinish)
  - [x] E2E: Redis xadd latency 0.2–3.4ms per chunk
  - [x] E2E: Chat sessions registered in Redis (confirmed via redis-cli)
2026-03-25 01:08:49 +00:00
Zamil Majdy
8528dffbf2 fix(backend): allow /tmp as valid path in E2B sandbox file tools (#12501)
## Summary
- Allow `/tmp` as a valid writable directory in E2B sandbox file tools
(`write_file`, `read_file`, `edit_file`, `glob`, `grep`)
- The E2B sandbox is already fully isolated, so restricting writes to
only `/home/user` was unnecessarily limiting — scripts and tools
commonly use `/tmp` for temporary files
- Extract `is_within_allowed_dirs()` helper in `context.py` to
centralize the allowed-directory check for both path resolution and
symlink escape detection

## Changes
- `context.py`: Add `E2B_ALLOWED_DIRS` tuple and `E2B_ALLOWED_DIRS_STR`,
introduce `is_within_allowed_dirs()`, update `resolve_sandbox_path()` to
use it
- `e2b_file_tools.py`: Update `_check_sandbox_symlink_escape()` to use
`is_within_allowed_dirs()`, update tool descriptions
- Tests: Add coverage for `/tmp` paths in both `context_test.py` and
`e2b_file_tools_test.py`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] All 59 existing + new tests pass (`poetry run pytest
backend/copilot/context_test.py
backend/copilot/sdk/e2b_file_tools_test.py`)
  - [x] `poetry run format` and `poetry run lint` pass clean
  - [x] Verify `/tmp` write works in live E2B sandbox
  - [x] E2E: Write file to /tmp/test.py in E2B sandbox via copilot
  - [x] E2E: Execute script from /tmp — output "Hello, World!"
  - [x] E2E: E2B sandbox lifecycle (create, use, pause) works correctly
2026-03-25 00:52:58 +00:00
Zamil Majdy
8fbf6a4b09 Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-25 06:55:47 +07:00
Zamil Majdy
239148596c fix(backend): filter SDK default credentials from credentials API responses (#12544)
## Summary

- Filter SDK-provisioned default credentials from credentials API list
endpoints
- Reuse `CredentialsMetaResponse` model from internal router in external
API (removes duplicate `CredentialSummary`)
- Add `is_sdk_default()` helper for identifying platform-provisioned
credentials
- Add `provider_matches()` to credential store for consistent provider
filtering
- Add tests for credential filtering behavior

### Changes
- `backend/data/model.py` — add `is_sdk_default()` helper
- `backend/api/features/integrations/router.py` — filter SDK defaults
from list endpoints
- `backend/api/external/v1/integrations.py` — reuse
`CredentialsMetaResponse`, filter SDK defaults
- `backend/integrations/credentials_store.py` — add `provider_matches()`
- `backend/sdk/registry.py` — update credential registration
- `backend/api/features/integrations/router_test.py` — new tests
- `backend/api/features/integrations/conftest.py` — test fixtures

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan

#### Test plan:
- [x] Unit tests for credential filtering (`router_test.py`)
- [x] Verify SDK default credentials excluded from API responses
- [x] Verify user-created credentials still returned normally
2026-03-25 06:54:54 +07:00
Zamil Majdy
a880d73481 feat(platform): dry-run execution mode with LLM block simulation (#12483)
## Why

Agent generation and building needs a way to test-run agents without
requiring real credentials or producing side effects. Currently, every
execution hits real APIs, consumes credits, and requires valid
credentials — making it impossible to debug or validate agent graphs
during the build phase without real consequences.

## Summary

Adds a `dry_run` execution mode to the copilot's `run_block` and
`run_agent` tools. When `dry_run=True`, every block execution is
simulated by an LLM instead of calling the real service — no real API
calls, no credentials consumed, no side effects.

Inspired by
[Significant-Gravitas/agent-simulator](https://github.com/Significant-Gravitas/agent-simulator).

### How it works

- **`backend/executor/simulator.py`** (new): `simulate_block()` builds a
prompt from the block's name, description, input/output schemas, and
actual input values, then calls `gpt-4o-mini` via the existing
OpenRouter client with JSON mode. Retries up to 5 times on JSON parse
failures. Missing output pins are filled with `None` (or `""` for the
`error` pin). Long inputs (>20k chars) are truncated before sending to
the LLM.
- **`ExecutionContext`**: Added `dry_run: bool = False` field; threaded
through `add_graph_execution()` so graph-level dry runs propagate to
every block execution.
- **`execute_block()` helper**: When `dry_run=True`, the function
short-circuits before any credential injection or credit checks, calls
`simulate_block()`, and returns a `[DRY RUN]`-prefixed
`BlockOutputResponse`.
- **`RunBlockTool`**: New `dry_run` boolean parameter.
- **`RunAgentTool`**: New `dry_run` boolean parameter; passes
`ExecutionContext(dry_run=True)` to graph execution.

### Tests

11 tests in `backend/copilot/tools/test_dry_run.py`:
- Correct output tuples from LLM response
- JSON retry logic (3 total calls when first 2 fail)
- All-retries-exhausted yields `SIMULATOR ERROR`
- Missing output pins filled with `None`/`""`
- No-client case
- Input truncation at 20k chars
- `execute_block(dry_run=True)` skips real `block.execute()`
- Response format: `[DRY RUN]` message, `success=True`
- `dry_run=False` unchanged (real path)
- `RunBlockTool` parameter presence
- `dry_run` kwarg forwarding

## Test plan
- [x] Run `pytest backend/copilot/tools/test_dry_run.py -v` — all 11
pass
- [x] Call `run_block` with `dry_run=true` in copilot; verify no real
API calls occur and output contains `[DRY RUN]`
- [x] Call `run_agent` with `dry_run=true`; verify execution is created
with `dry_run=True` in context
- [x] E2E: Simulate button (flask icon) present in builder alongside
play button
- [x] E2E: Simulated run labeled with "(Simulated)" suffix and badge in
Library
- [x] E2E: No credits consumed during dry-run
2026-03-24 22:36:47 +00:00
215 changed files with 18921 additions and 1725 deletions

View File

@@ -0,0 +1,106 @@
---
name: open-pr
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
user-invocable: true
args: "[base-branch] — optional target branch (defaults to dev)."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Open a Pull Request
## Step 1: Pre-flight checks
Before opening the PR:
1. Ensure all changes are committed
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
## Step 2: Test coverage
**This is critical.** Before opening the PR, verify:
### Existing behavior is not broken
- Identify which modules/components your changes touch
- Run the existing test suites for those areas
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
### New behavior has test coverage
- Every new feature, endpoint, or behavior change needs tests
- If you added a new block, add tests for that block
- If you changed API behavior, add or update API tests
- If you changed frontend behavior, verify it doesn't break existing flows
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
## Step 3: Create the PR using the repo template
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
2. Preserve the exact section titles and formatting, including:
- `### Why / What / How`
- `### Changes 🏗️`
- `### Checklist 📋`
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
5. Do not alter the template structure, rename sections, or remove any checklist items
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
```bash
BASE_BRANCH="${BASE_BRANCH:-dev}"
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
PREOF
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
rm "$PR_BODY"
```
## Step 4: Review workflow
### If you have a workspace that allows testing (docker, running backend, etc.)
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
### If you do NOT have a workspace that allows testing
This is common for agents running in worktrees without a full stack. In this case:
1. Run `/pr-review` locally to catch obvious issues before pushing
2. **Comment `/review` on the PR** after creating it to trigger the review bot
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
4. Do NOT proceed or merge until the bot review comes back
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
```bash
# After creating the PR:
PR_NUMBER=$(gh pr view --json number -q .number)
gh pr comment "$PR_NUMBER" --body "/review"
# Then use /pr-address to poll for and address the review when it arrives
```
## Step 5: Address review feedback
Once the review bot or human reviewers leave comments:
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
- Do not merge without human approval.
## Related skills
| Skill | When to use |
|---|---|
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
## Step 6: Post-creation
After the PR is created and review is triggered:
- Share the PR URL with the user
- If waiting on the review bot, let the user know the expected wait time (~30 min)
- Do not merge without human approval

View File

@@ -113,7 +113,9 @@ kill $REST_PID 2>/dev/null; trap - EXIT
```
Never manually edit files in `src/app/api/__generated__/`.
Then commit and **push immediately** — never batch commits without pushing.
Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.
**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).

View File

@@ -0,0 +1,195 @@
---
name: setup-repo
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
user-invocable: true
args: "No arguments — interactive setup via prompts."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Repository Setup
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
- A **main** worktree (the primary checkout)
- A **reviews** worktree (for PR reviews)
- **N work branches** (branch1..branchN) for parallel development
## Step 1: Identify the repo
Determine the repo root and parent directory:
```bash
ROOT=$(git rev-parse --show-toplevel)
REPO_NAME=$(basename "$ROOT")
PARENT=$(dirname "$ROOT")
```
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
```bash
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
if [ "$SIBLING_COUNT" -gt 1 ]; then
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
# Use $ROOT as-is; skip renaming/restructuring
else
echo "INFO: Fresh clone detected, proceeding with setup"
fi
```
## Step 2: Ask the user questions
Use AskUserQuestion to gather setup preferences:
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
- These become `branch1` through `branchN`
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
- All work branches and reviews will start from this
## Step 3: Fetch and set up branches
```bash
cd "$ROOT"
git fetch origin
# Create the reviews branch from base (skip if already exists)
if git show-ref --verify --quiet refs/heads/reviews; then
echo "INFO: Branch 'reviews' already exists, skipping"
else
git branch reviews <base-branch>
fi
# Create numbered work branches from base (skip if already exists)
for i in $(seq 1 "$COUNT"); do
if git show-ref --verify --quiet "refs/heads/branch$i"; then
echo "INFO: Branch 'branch$i' already exists, skipping"
else
git branch "branch$i" <base-branch>
fi
done
```
## Step 4: Create worktrees
Create worktrees as siblings to the main checkout:
```bash
if [ -d "$PARENT/reviews" ]; then
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
else
git worktree add "$PARENT/reviews" reviews
fi
for i in $(seq 1 "$COUNT"); do
if [ -d "$PARENT/branch$i" ]; then
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
else
git worktree add "$PARENT/branch$i" "branch$i"
fi
done
```
## Step 5: Set up environment files
**Do NOT assume .env files exist.** For each worktree (including main if needed):
1. Check if `.env` exists in the source worktree for each path
2. If `.env` exists, copy it
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
4. If neither exists, warn the user and list which env files are missing
Env file locations to check (same as the `/worktree` skill — keep these in sync):
- `autogpt_platform/.env`
- `autogpt_platform/backend/.env`
- `autogpt_platform/frontend/.env`
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
```bash
SOURCE="$ROOT"
WORKTREES="reviews"
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
FOUND_ANY_ENV=0
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
if [ -f "$SOURCE/$envpath/.env" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
FOUND_ANY_ENV=1
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
else
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
fi
done
done
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
echo "WARNING: No environment files or templates were found in the source worktree."
# Use AskUserQuestion to confirm: "Continue setup without env files?"
# If the user declines, stop here and let them set up .env files first.
fi
```
## Step 6: Copy branchlet config
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
```bash
if [ -f "$ROOT/.branchlet.json" ]; then
for wt in $WORKTREES; do
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
done
fi
```
## Step 7: Install dependencies
Install deps in all worktrees. Run these sequentially per worktree:
```bash
for wt in $WORKTREES; do
TARGET="$PARENT/$wt"
echo "=== Installing deps for $wt ==="
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
echo "=== Done: $wt ===" ||
echo "=== FAILED: $wt ==="
done
```
This is slow. Run in background if possible and notify when complete.
## Step 8: Verify and report
After setup, verify and report to the user:
```bash
git worktree list
```
Summarize:
- Number of worktrees created
- Which env files were copied vs created from defaults vs missing
- Any warnings or errors encountered
## Final directory layout
```
parent/
main/ # Primary checkout (already exists)
reviews/ # PR review worktree
branch1/ # Work branch 1
branch2/ # Work branch 2
...
branchN/ # Work branch N
```

View File

@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.

View File

@@ -53,6 +53,7 @@ AutoGPT Platform is a monorepo containing:
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.

View File

@@ -178,6 +178,7 @@ SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

View File

@@ -61,6 +61,7 @@ poetry run pytest path/to/test.py --snapshot-update
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code

View File

@@ -18,14 +18,22 @@ from pydantic import BaseModel, Field, SecretStr
from backend.api.external.middleware import require_permission
from backend.api.features.integrations.models import get_all_provider_names
from backend.api.features.integrations.router import (
CredentialsMetaResponse,
to_meta_response,
)
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.model import (
APIKeyCredentials,
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
is_sdk_default,
)
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
@@ -91,18 +99,6 @@ class OAuthCompleteResponse(BaseModel):
)
class CredentialSummary(BaseModel):
"""Summary of a credential without sensitive data."""
id: str
provider: str
type: CredentialsType
title: Optional[str] = None
scopes: Optional[list[str]] = None
username: Optional[str] = None
host: Optional[str] = None
class ProviderInfo(BaseModel):
"""Information about an integration provider."""
@@ -473,12 +469,12 @@ async def complete_oauth(
)
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
async def list_credentials(
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
) -> list[CredentialsMetaResponse]:
"""
List all credentials for the authenticated user.
@@ -486,28 +482,19 @@ async def list_credentials(
"""
credentials = await creds_manager.store.get_all_creds(auth.user_id)
return [
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
]
@integrations_router.get(
"/{provider}/credentials", response_model=list[CredentialSummary]
"/{provider}/credentials", response_model=list[CredentialsMetaResponse]
)
async def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
) -> list[CredentialsMetaResponse]:
"""
List credentials for a specific provider.
"""
@@ -515,16 +502,7 @@ async def list_credentials_by_provider(
auth.user_id, provider
)
return [
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
]
@@ -597,11 +575,11 @@ async def create_credential(
# Store credentials
try:
await creds_manager.create(auth.user_id, credentials)
except Exception as e:
logger.error(f"Failed to store credentials: {e}")
except Exception:
logger.exception("Failed to store credentials")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to store credentials: {str(e)}",
detail="Failed to store credentials",
)
logger.info(f"Created {request.type} credentials for provider {provider}")
@@ -639,15 +617,23 @@ async def delete_credential(
use the main API's delete endpoint which handles webhook cleanup and
token revocation.
"""
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if creds.provider != provider:
if not provider_matches(creds.provider, provider):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
await creds_manager.delete(auth.user_id, cred_id)

View File

@@ -0,0 +1,146 @@
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
import logging
from typing import Optional
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, Body, HTTPException, Security
from pydantic import BaseModel
from backend.copilot.config import ChatConfig
from backend.copilot.rate_limit import (
get_global_rate_limits,
get_usage_status,
reset_user_usage,
)
from backend.data.user import get_user_by_email, get_user_email_by_id
logger = logging.getLogger(__name__)
config = ChatConfig()
router = APIRouter(
prefix="/admin",
tags=["copilot", "admin"],
dependencies=[Security(requires_admin_user)],
)
class UserRateLimitResponse(BaseModel):
user_id: str
user_email: Optional[str] = None
daily_token_limit: int
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
async def _resolve_user_id(
user_id: Optional[str], email: Optional[str]
) -> tuple[str, Optional[str]]:
"""Resolve a user_id and email from the provided parameters.
Returns (user_id, email). Accepts either user_id or email; at least one
must be provided. When both are provided, ``email`` takes precedence.
"""
if email:
user = await get_user_by_email(email)
if not user:
raise HTTPException(
status_code=404, detail="No user found with the provided email."
)
return user.id, email
if not user_id:
raise HTTPException(
status_code=400,
detail="Either user_id or email query parameter is required.",
)
# We have a user_id; try to look up their email for display purposes.
# This is non-critical -- a failure should not block the response.
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return user_id, resolved_email
@router.get(
"/rate_limit",
response_model=UserRateLimitResponse,
summary="Get User Rate Limit",
)
async def get_user_rate_limit(
user_id: Optional[str] = None,
email: Optional[str] = None,
admin_user_id: str = Security(get_user_id),
) -> UserRateLimitResponse:
"""Get a user's current usage and effective rate limits. Admin-only.
Accepts either ``user_id`` or ``email`` as a query parameter.
When ``email`` is provided the user is looked up by email first.
"""
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
return UserRateLimitResponse(
user_id=resolved_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
)
@router.post(
"/rate_limit/reset",
response_model=UserRateLimitResponse,
summary="Reset User Rate Limit Usage",
)
async def reset_user_rate_limit(
user_id: str = Body(embed=True),
reset_weekly: bool = Body(False, embed=True),
admin_user_id: str = Security(get_user_id),
) -> UserRateLimitResponse:
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
logger.info(
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
admin_user_id,
user_id,
reset_weekly,
)
try:
await reset_user_usage(user_id, reset_weekly=reset_weekly)
except Exception as e:
logger.exception("Failed to reset user usage")
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
try:
resolved_email = await get_user_email_by_id(user_id)
except Exception:
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
resolved_email = None
return UserRateLimitResponse(
user_id=user_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
)

View File

@@ -0,0 +1,263 @@
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from .rate_limit_admin_routes import router as rate_limit_admin_router
app = fastapi.FastAPI()
app.include_router(rate_limit_admin_router)
client = fastapi.testclient.TestClient(app)
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
_TARGET_EMAIL = "target@example.com"
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _mock_usage_status(
daily_used: int = 500_000, weekly_used: int = 3_000_000
) -> CoPilotUsageStatus:
from datetime import UTC, datetime, timedelta
now = datetime.now(UTC)
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
),
weekly=UsageWindow(
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
),
)
def _patch_rate_limit_deps(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
daily_used: int = 500_000,
weekly_used: int = 3_000_000,
):
"""Patch the common rate-limit + user-lookup dependencies."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
def test_get_rate_limit(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test getting rate limit and usage for a user."""
_patch_rate_limit_deps(mocker, target_user_id)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"get_rate_limit",
)
def test_get_rate_limit_by_email(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test looking up rate limits via email instead of user_id."""
_patch_rate_limit_deps(mocker, target_user_id)
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=mock_user,
)
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
def test_get_rate_limit_by_email_not_found(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Test that looking up a non-existent email returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_by_email",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
assert response.status_code == 404
def test_get_rate_limit_no_params() -> None:
"""Test that omitting both user_id and email returns 400."""
response = client.get("/admin/rate_limit")
assert response.status_code == 400
def test_reset_user_usage_daily_only(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test resetting only daily usage (default behaviour)."""
mock_reset = mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id},
)
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"reset_user_usage_daily_only",
)
def test_reset_user_usage_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
configured_snapshot: Snapshot,
target_user_id: str,
) -> None:
"""Test resetting both daily and weekly usage."""
mock_reset = mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
)
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id, "reset_weekly": True},
)
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
"reset_user_usage_daily_and_weekly",
)
def test_reset_user_usage_redis_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that Redis failure on reset returns 500."""
mocker.patch(
f"{_MOCK_MODULE}.reset_user_usage",
new_callable=AsyncMock,
side_effect=Exception("Redis connection refused"),
)
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": target_user_id},
)
assert response.status_code == 500
def test_get_rate_limit_email_lookup_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that failing to resolve a user email degrades gracefully."""
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
new_callable=AsyncMock,
return_value=_mock_usage_status(),
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection lost"),
)
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] is None
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that rate limit admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit", params={"user_id": "test"})
assert response.status_code == 403
response = client.post(
"/admin/rate_limit/reset",
json={"user_id": "test"},
)
assert response.status_code == 403

View File

@@ -7,6 +7,8 @@ import fastapi
import fastapi.responses
import prisma.enums
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
@@ -132,3 +134,40 @@ async def admin_download_agent_file(
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
@router.get(
"/submissions/{store_listing_version_id}/preview",
summary="Admin Preview Submission Listing",
)
async def admin_preview_submission(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
"""
Preview a marketplace submission as it would appear on the listing page.
Bypasses the APPROVED-only StoreAgent view so admins can preview pending
submissions before approving.
"""
return await store_db.get_store_agent_details_as_admin(store_listing_version_id)
@router.post(
"/submissions/{store_listing_version_id}/add-to-library",
summary="Admin Add Pending Agent to Library",
status_code=201,
)
async def admin_add_agent_to_library(
store_listing_version_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> library_model.LibraryAgent:
"""
Add a pending marketplace agent to the admin's library for review.
Uses admin-level access to bypass marketplace APPROVED-only checks.
The builder can load the graph because get_graph() checks library
membership as a fallback: "you added it, you keep it."
"""
return await library_db.add_store_agent_to_library_as_admin(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)

View File

@@ -1,14 +1,33 @@
"""Tests for admin store routes and the bypass logic they depend on.
Tests are organized by what they protect:
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
- SECRT-2167 security: admin endpoints reject non-admin users
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
and add-to-library uses get_graph_as_admin (not get_graph)
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.responses
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.graph import get_graph_as_admin
from backend.util.exceptions import NotFoundError
from .store_admin_routes import router as store_admin_router
# Shared constants
ADMIN_USER_ID = "admin-user-id"
CREATOR_USER_ID = "other-creator-id"
GRAPH_ID = "test-graph-id"
GRAPH_VERSION = 3
SLV_ID = "test-store-listing-version-id"
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
@@ -20,18 +39,18 @@ def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
return graph
# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #
@pytest.mark.asyncio
async def test_admin_can_access_pending_agent_not_owned() -> None:
"""Admin must be able to access a graph they don't own even if it's not
APPROVED in the marketplace. This is the core use case: reviewing a
submitted-but-pending agent from the admin dashboard."""
"""get_graph_as_admin must return a graph even when the admin doesn't own
it and it's not APPROVED in the marketplace."""
mock_graph = _make_mock_graph()
mock_graph_model = MagicMock(name="GraphModel")
with (
patch(
"backend.data.graph.AgentGraph.prisma",
) as mock_prisma,
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
@@ -46,25 +65,19 @@ async def test_admin_can_access_pending_agent_not_owned() -> None:
for_export=False,
)
assert (
result is not None
), "Admin should be able to access a pending agent they don't own"
assert result is mock_graph_model
@pytest.mark.asyncio
async def test_admin_download_pending_agent_with_subagents() -> None:
"""Admin export (for_export=True) of a pending agent must include
sub-graphs. This exercises the full export code path that the Download
button uses."""
"""get_graph_as_admin with for_export=True must call get_sub_graphs
and pass sub_graphs to GraphModel.from_db."""
mock_graph = _make_mock_graph()
mock_sub_graph = MagicMock(name="SubGraph")
mock_graph_model = MagicMock(name="GraphModel")
with (
patch(
"backend.data.graph.AgentGraph.prisma",
) as mock_prisma,
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
patch(
"backend.data.graph.get_sub_graphs",
new_callable=AsyncMock,
@@ -84,10 +97,239 @@ async def test_admin_download_pending_agent_with_subagents() -> None:
for_export=True,
)
assert result is not None, "Admin export of pending agent must succeed"
assert result is mock_graph_model
mock_get_sub.assert_awaited_once_with(mock_graph)
mock_from_db.assert_called_once_with(
graph=mock_graph,
sub_graphs=[mock_sub_graph],
for_export=True,
)
# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #
app = fastapi.FastAPI()
app.include_router(store_admin_router)
@app.exception_handler(NotFoundError)
async def _not_found_handler(
request: fastapi.Request, exc: NotFoundError
) -> fastapi.responses.JSONResponse:
return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all route tests in this module."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_preview_requires_admin(mock_jwt_user) -> None:
"""Non-admin users must get 403 on the preview endpoint."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
assert response.status_code == 403
def test_add_to_library_requires_admin(mock_jwt_user) -> None:
"""Non-admin users must get 403 on the add-to-library endpoint."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
assert response.status_code == 403
def test_preview_nonexistent_submission(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Preview of a nonexistent submission returns 404."""
mocker.patch(
"backend.api.features.admin.store_admin_routes.store_db"
".get_store_agent_details_as_admin",
side_effect=NotFoundError("not found"),
)
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
assert response.status_code == 404
# ---- SECRT-2167 bypass: verify the right data sources are used ---- #
@pytest.mark.asyncio
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
"""get_store_agent_details_as_admin must query StoreListingVersion
directly (not the APPROVED-only StoreAgent view). This is THE test that
prevents the bypass from being accidentally reverted."""
from backend.api.features.store.db import get_store_agent_details_as_admin
mock_slv = MagicMock()
mock_slv.id = SLV_ID
mock_slv.name = "Test Agent"
mock_slv.subHeading = "Short desc"
mock_slv.description = "Long desc"
mock_slv.videoUrl = None
mock_slv.agentOutputDemoUrl = None
mock_slv.imageUrls = ["https://example.com/img.png"]
mock_slv.instructions = None
mock_slv.categories = ["productivity"]
mock_slv.version = 1
mock_slv.agentGraphId = GRAPH_ID
mock_slv.agentGraphVersion = GRAPH_VERSION
mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
mock_slv.recommendedScheduleCron = "0 9 * * *"
mock_listing = MagicMock()
mock_listing.id = "listing-id"
mock_listing.slug = "test-agent"
mock_listing.activeVersionId = SLV_ID
mock_listing.hasApprovedVersion = False
mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
mock_slv.StoreListing = mock_listing
with (
patch(
"backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
) as mock_slv_prisma,
patch(
"backend.api.features.store.db.prisma.models.StoreAgent.prisma",
) as mock_store_agent_prisma,
):
mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
result = await get_store_agent_details_as_admin(SLV_ID)
# Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
mock_slv_prisma.return_value.find_unique.assert_awaited_once()
await_args = mock_slv_prisma.return_value.find_unique.await_args
assert await_args is not None
assert await_args.kwargs["where"] == {"id": SLV_ID}
# Verify the APPROVED-only StoreAgent view was NOT touched
mock_store_agent_prisma.assert_not_called()
# Verify the result has the right data
assert result.agent_name == "Test Agent"
assert result.agent_image == ["https://example.com/img.png"]
assert result.has_approved_version is False
assert result.runs == 0
assert result.rating == 0.0
@pytest.mark.asyncio
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
"""resolve_graph_for_library(admin=True) must call get_graph_as_admin,
not get_graph. This is THE test that prevents the add-to-library bypass
from being accidentally reverted."""
from backend.api.features.library._add_to_library import resolve_graph_for_library
mock_slv = MagicMock()
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
mock_graph_model = MagicMock(name="GraphModel")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models"
".StoreListingVersion.prisma",
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.graph_db"
".get_graph_as_admin",
new_callable=AsyncMock,
return_value=mock_graph_model,
) as mock_admin,
patch(
"backend.api.features.library._add_to_library.graph_db.get_graph",
new_callable=AsyncMock,
) as mock_regular,
):
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)
assert result is mock_graph_model
mock_admin.assert_awaited_once_with(
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
)
mock_regular.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_graph_regular_uses_get_graph() -> None:
"""resolve_graph_for_library(admin=False) must call get_graph,
not get_graph_as_admin. Ensures the non-admin path is preserved."""
from backend.api.features.library._add_to_library import resolve_graph_for_library
mock_slv = MagicMock()
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
mock_graph_model = MagicMock(name="GraphModel")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models"
".StoreListingVersion.prisma",
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.graph_db"
".get_graph_as_admin",
new_callable=AsyncMock,
) as mock_admin,
patch(
"backend.api.features.library._add_to_library.graph_db.get_graph",
new_callable=AsyncMock,
return_value=mock_graph_model,
) as mock_regular,
):
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)
assert result is mock_graph_model
mock_regular.assert_awaited_once_with(
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
)
mock_admin.assert_not_awaited()
# ---- Library membership grants graph access (product decision) ---- #
@pytest.mark.asyncio
async def test_library_member_can_view_pending_agent_in_builder() -> None:
"""After adding a pending agent to their library, the user should be
able to load the graph in the builder via get_graph()."""
mock_graph = _make_mock_graph()
mock_graph_model = MagicMock(name="GraphModel")
mock_library_agent = MagicMock()
mock_library_agent.AgentGraph = mock_graph
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_lib_prisma.return_value.find_first = AsyncMock(
return_value=mock_library_agent
)
from backend.data.graph import get_graph
result = await get_graph(
graph_id=GRAPH_ID,
version=GRAPH_VERSION,
user_id=ADMIN_USER_ID,
)
assert result is mock_graph_model, "Library membership should grant graph access"

View File

@@ -30,8 +30,14 @@ from backend.copilot.model import (
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
acquire_reset_lock,
check_rate_limit,
get_daily_reset_count,
get_global_rate_limits,
get_usage_status,
increment_daily_reset_count,
release_reset_lock,
reset_daily_usage,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
@@ -59,9 +65,16 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
config = ChatConfig()
@@ -69,8 +82,6 @@ _UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
async def _validate_and_get_session(
session_id: str,
@@ -421,11 +432,187 @@ async def get_copilot_usage(
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
"""
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
)
class RateLimitResetResponse(BaseModel):
"""Response from resetting the daily rate limit."""
success: bool
credits_charged: int = Field(description="Credits charged (in cents)")
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
@router.post(
"/usage/reset",
status_code=200,
responses={
400: {
"description": "Bad Request (feature disabled or daily limit not reached)"
},
402: {"description": "Payment Required (insufficient credits)"},
429: {
"description": "Too Many Requests (max daily resets exceeded or reset in progress)"
},
503: {
"description": "Service Unavailable (Redis reset failed; credits refunded or support needed)"
},
},
)
async def reset_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> RateLimitResetResponse:
"""Reset the daily CoPilot rate limit by spending credits.
Allows users who have hit their daily token limit to spend credits
to reset their daily usage counter and continue working.
Returns 400 if the feature is disabled or the user is not over the limit.
Returns 402 if the user has insufficient credits.
"""
cost = config.rate_limit_reset_cost
if cost <= 0:
raise HTTPException(
status_code=400,
detail="Rate limit reset is not available.",
)
if not settings.config.enable_credit:
raise HTTPException(
status_code=400,
detail="Rate limit reset is not available (credit system is disabled).",
)
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
if daily_limit <= 0:
raise HTTPException(
status_code=400,
detail="No daily limit is configured — nothing to reset.",
)
# Check max daily resets. get_daily_reset_count returns None when Redis
# is unavailable; reject the reset in that case to prevent unlimited
# free resets when the counter store is down.
reset_count = await get_daily_reset_count(user_id)
if reset_count is None:
raise HTTPException(
status_code=503,
detail="Unable to verify reset eligibility — please try again later.",
)
if config.max_daily_resets > 0 and reset_count >= config.max_daily_resets:
raise HTTPException(
status_code=429,
detail=f"You've used all {config.max_daily_resets} resets for today.",
)
# Acquire a per-user lock to prevent TOCTOU races (concurrent resets).
if not await acquire_reset_lock(user_id):
raise HTTPException(
status_code=429,
detail="A reset is already in progress. Please try again.",
)
try:
# Verify the user is actually at or over their daily limit.
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
raise HTTPException(
status_code=400,
detail="You have not reached your daily limit yet.",
)
# If the weekly limit is also exhausted, resetting the daily counter
# won't help — the user would still be blocked by the weekly limit.
if weekly_limit > 0 and usage_status.weekly.used >= weekly_limit:
raise HTTPException(
status_code=400,
detail="Your weekly limit is also reached. Resetting the daily limit won't help.",
)
# Charge credits.
credit_model = await get_user_credit_model(user_id)
try:
remaining = await credit_model.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
reason="CoPilot daily rate limit reset",
),
)
except InsufficientBalanceError as e:
raise HTTPException(
status_code=402,
detail="Insufficient credits to reset your rate limit.",
) from e
# Reset daily usage in Redis. If this fails, refund the credits
# so the user is not charged for a service they did not receive.
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
# Compensate: refund the charged credits.
refunded = False
try:
await credit_model.top_up_credits(user_id, cost)
refunded = True
logger.warning(
"Refunded %d credits to user %s after Redis reset failure",
cost,
user_id[:8],
)
except Exception:
logger.error(
"CRITICAL: Failed to refund %d credits to user %s "
"after Redis reset failure — manual intervention required",
cost,
user_id[:8],
exc_info=True,
)
if refunded:
raise HTTPException(
status_code=503,
detail="Rate limit reset failed — please try again later. "
"Your credits have not been charged.",
)
raise HTTPException(
status_code=503,
detail="Rate limit reset failed and the automatic refund "
"also failed. Please contact support for assistance.",
)
# Track the reset count for daily cap enforcement.
await increment_daily_reset_count(user_id)
finally:
await release_reset_lock(user_id)
# Return updated usage status.
updated_usage = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
)
return RateLimitResetResponse(
success=True,
credits_charged=cost,
remaining_balance=remaining,
usage=updated_usage,
)
@@ -526,12 +713,16 @@ async def stream_chat_post(
# Pre-turn rate limit check (token-based).
# check_rate_limit short-circuits internally when both limits are 0.
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
@@ -894,6 +1085,47 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedTheme(BaseModel):
"""A themed group of suggested prompts."""
name: str
prompts: list[str]
class SuggestedPromptsResponse(BaseModel):
"""Response model for user-specific suggested prompts grouped by theme."""
themes: list[SuggestedTheme]
@router.get(
"/suggested-prompts",
dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
"""
Get LLM-generated suggested prompts grouped by theme.
Returns personalized quick-action prompts based on the user's
business understanding. Returns empty themes list if no custom
prompts are available.
"""
understanding = await get_business_understanding(user_id)
if understanding is None or not understanding.suggested_prompts:
return SuggestedPromptsResponse(themes=[])
themes = [
SuggestedTheme(name=name, prompts=prompts)
for name, prompts in understanding.suggested_prompts.items()
]
return SuggestedPromptsResponse(themes=themes)
# ========== Configuration ==========

View File

@@ -1,7 +1,7 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock
import fastapi
import fastapi.testclient
@@ -368,6 +368,7 @@ def test_usage_returns_daily_and_weekly(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
)
@@ -380,6 +381,7 @@ def test_usage_uses_config_limits(
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
response = client.get("/usage")
@@ -388,6 +390,7 @@ def test_usage_uses_config_limits(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
rate_limit_reset_cost=500,
)
@@ -400,3 +403,69 @@ def test_usage_rejects_unauthenticated_request() -> None:
response = unauthenticated_client.get("/usage")
assert response.status_code == 401
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
mocker: pytest_mock.MockerFixture,
*,
return_value=None,
):
"""Mock get_business_understanding."""
return mocker.patch(
"backend.api.features.chat.routes.get_business_understanding",
new_callable=AsyncMock,
return_value=return_value,
)
def test_suggested_prompts_returns_themes(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with themed prompts gets them back as themes list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = {
"Learn": ["L1", "L2"],
"Create": ["C1"],
}
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
data = response.json()
assert "themes" in data
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
assert themes_by_name["Learn"] == ["L1", "L2"]
assert themes_by_name["Create"] == ["C1"]
def test_suggested_prompts_no_understanding(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with no understanding gets empty themes list."""
_mock_get_business_understanding(mocker, return_value=None)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"themes": []}
def test_suggested_prompts_empty_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding but empty prompts gets empty themes list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = {}
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"themes": []}

View File

@@ -0,0 +1,13 @@
"""Override session-scoped fixtures so unit tests run without the server."""
import pytest
@pytest.fixture(scope="session")
def server():
yield None
@pytest.fixture(scope="session", autouse=True)
def graph_cleanup():
yield

View File

@@ -34,16 +34,21 @@ from backend.data.model import (
HostScopedCredentials,
OAuth2Credentials,
UserIntegrations,
is_sdk_default,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -109,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
)
is_managed: bool = False
@model_validator(mode="before")
@classmethod
@@ -138,6 +144,19 @@ class CredentialsMetaResponse(BaseModel):
return None
def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
return CredentialsMetaResponse(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
is_managed=cred.is_managed,
)
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
provider: Annotated[
@@ -204,34 +223,20 @@ async def callback(
f"and provider {provider.value}"
)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)),
)
return to_meta_response(credentials)
@router.get("/credentials", summary="List Credentials")
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
CredentialsMetaResponse(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
)
for cred in credentials
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
]
@@ -242,19 +247,11 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
CredentialsMetaResponse(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
)
for cred in credentials
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
]
@@ -267,18 +264,21 @@ async def get_credential(
],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Security(get_user_id)],
) -> Credentials:
) -> CredentialsMetaResponse:
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
credential = await creds_manager.get(user_id, cred_id)
if not credential:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if credential.provider != provider:
if not provider_matches(credential.provider, provider):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
return credential
return to_meta_response(credential)
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
@@ -288,16 +288,22 @@ async def create_credentials(
ProviderName, Path(title="The provider to create credentials for")
],
credentials: Credentials,
) -> Credentials:
) -> CredentialsMetaResponse:
if is_sdk_default(credentials.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot create credentials with a reserved ID",
)
credentials.provider = provider
try:
await creds_manager.create(user_id, credentials)
except Exception as e:
except Exception:
logger.exception("Failed to store credentials")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to store credentials: {str(e)}",
detail="Failed to store credentials",
)
return credentials
return to_meta_response(credentials)
class CredentialsDeletionResponse(BaseModel):
@@ -332,15 +338,29 @@ async def delete_credentials(
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
] = False,
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if creds.provider != provider:
if not provider_matches(creds.provider, provider):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
detail="Credentials not found",
)
if creds.is_managed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="AutoGPT-managed credentials cannot be deleted",
)
try:

View File

@@ -0,0 +1,570 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
import pytest
from pydantic import SecretStr
from backend.api.features.integrations.router import router
from backend.data.model import (
APIKeyCredentials,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
)
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "test-user-id"
def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
return APIKeyCredentials(
id=cred_id,
provider=provider,
title="My API Key",
api_key=SecretStr("sk-secret-key-value"),
)
def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
return OAuth2Credentials(
id=cred_id,
provider=provider,
title="My OAuth",
access_token=SecretStr("ghp_secret_token"),
refresh_token=SecretStr("ghp_refresh_secret"),
scopes=["repo", "user"],
username="testuser",
)
def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
return UserPasswordCredentials(
id=cred_id,
provider=provider,
title="My Login",
username=SecretStr("admin"),
password=SecretStr("s3cret-pass"),
)
def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
return HostScopedCredentials(
id=cred_id,
provider=provider,
title="Host Cred",
host="https://api.example.com",
headers={"Authorization": SecretStr("Bearer top-secret")},
)
def _make_sdk_default_cred(provider: str = "openai"):
return APIKeyCredentials(
id=f"{provider}-default",
provider=provider,
title=f"{provider} (default)",
api_key=SecretStr("sk-platform-secret-key"),
)
@pytest.fixture(autouse=True)
def setup_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
class TestGetCredentialReturnsMetaOnly:
"""GET /{provider}/credentials/{cred_id} must not return secrets."""
def test_api_key_credential_no_secret(self):
cred = _make_api_key_cred()
with (
patch.object(router, "dependencies", []),
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
):
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/openai/credentials/cred-123")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-123"
assert data["provider"] == "openai"
assert data["type"] == "api_key"
assert "api_key" not in data
assert "sk-secret-key-value" not in str(data)
def test_oauth2_credential_no_secret(self):
cred = _make_oauth2_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/github/credentials/cred-456")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-456"
assert data["scopes"] == ["repo", "user"]
assert data["username"] == "testuser"
assert "access_token" not in data
assert "refresh_token" not in data
assert "ghp_" not in str(data)
def test_user_password_credential_no_secret(self):
cred = _make_user_password_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/openai/credentials/cred-789")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-789"
assert "password" not in data
assert "username" not in data or data["username"] is None
assert "s3cret-pass" not in str(data)
assert "admin" not in str(data)
def test_host_scoped_credential_no_secret(self):
cred = _make_host_scoped_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/openai/credentials/cred-host")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == "cred-host"
assert data["host"] == "https://api.example.com"
assert "headers" not in data
assert "top-secret" not in str(data)
def test_get_credential_wrong_provider_returns_404(self):
"""Provider mismatch should return generic 404, not leak credential existence."""
cred = _make_api_key_cred(provider="openai")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/github/credentials/cred-123")
assert resp.status_code == 404
assert resp.json()["detail"] == "Credentials not found"
def test_list_credentials_no_secrets(self):
"""List endpoint must not leak secrets in any credential."""
creds = [_make_api_key_cred(), _make_oauth2_cred()]
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
resp = client.get("/credentials")
assert resp.status_code == 200
raw = str(resp.json())
assert "sk-secret-key-value" not in raw
assert "ghp_secret_token" not in raw
assert "ghp_refresh_secret" not in raw
class TestSdkDefaultCredentialsNotAccessible:
"""SDK default credentials (ID ending in '-default') must be hidden."""
def test_get_sdk_default_returns_404(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock()
resp = client.get("/openai/credentials/openai-default")
assert resp.status_code == 404
mock_mgr.get.assert_not_called()
def test_list_credentials_excludes_sdk_defaults(self):
user_cred = _make_api_key_cred()
sdk_cred = _make_sdk_default_cred("openai")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
ids = [c["id"] for c in data]
assert "cred-123" in ids
assert "openai-default" not in ids
def test_list_by_provider_excludes_sdk_defaults(self):
user_cred = _make_api_key_cred()
sdk_cred = _make_sdk_default_cred("openai")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_provider = AsyncMock(
return_value=[user_cred, sdk_cred]
)
resp = client.get("/openai/credentials")
assert resp.status_code == 200
data = resp.json()
ids = [c["id"] for c in data]
assert "cred-123" in ids
assert "openai-default" not in ids
def test_delete_sdk_default_returns_404(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock()
resp = client.request("DELETE", "/openai/credentials/openai-default")
assert resp.status_code == 404
mock_mgr.store.get_creds_by_id.assert_not_called()
class TestCreateCredentialNoSecretInResponse:
"""POST /{provider}/credentials must not return secrets."""
def test_create_api_key_no_secret_in_response(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.create = AsyncMock()
resp = client.post(
"/openai/credentials",
json={
"id": "new-cred",
"provider": "openai",
"type": "api_key",
"title": "New Key",
"api_key": "sk-newsecret",
},
)
assert resp.status_code == 201
data = resp.json()
assert data["id"] == "new-cred"
assert "api_key" not in data
assert "sk-newsecret" not in str(data)
def test_create_with_sdk_default_id_rejected(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.create = AsyncMock()
resp = client.post(
"/openai/credentials",
json={
"id": "openai-default",
"provider": "openai",
"type": "api_key",
"title": "Sneaky",
"api_key": "sk-evil",
},
)
assert resp.status_code == 403
mock_mgr.create.assert_not_called()
class TestManagedCredentials:
"""AutoGPT-managed credentials cannot be deleted by users."""
def test_delete_is_managed_returns_403(self):
cred = APIKeyCredentials(
id="managed-cred-1",
provider="agent_mail",
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-managed-key"),
is_managed=True,
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
assert resp.status_code == 403
assert "AutoGPT-managed" in resp.json()["detail"]
def test_list_credentials_includes_is_managed_field(self):
managed = APIKeyCredentials(
id="managed-1",
provider="agent_mail",
title="AgentMail (managed)",
api_key=SecretStr("sk-key"),
is_managed=True,
)
regular = APIKeyCredentials(
id="regular-1",
provider="openai",
title="My Key",
api_key=SecretStr("sk-key"),
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
managed_cred = next(c for c in data if c["id"] == "managed-1")
regular_cred = next(c for c in data if c["id"] == "regular-1")
assert managed_cred["is_managed"] is True
assert regular_cred["is_managed"] is False
# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------
def _make_managed_cred(
provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
return APIKeyCredentials(
id="managed-auto",
provider=provider,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-pod-key"),
is_managed=True,
metadata={"pod_id": pod_id},
)
def _make_store_mock(**kwargs) -> MagicMock:
"""Create a store mock with a working async ``locks()`` context manager."""
@asynccontextmanager
async def _noop_locked(key):
yield
locks_obj = MagicMock()
locks_obj.locked = _noop_locked
store = MagicMock(**kwargs)
store.locks = AsyncMock(return_value=locks_obj)
return store
class TestEnsureManagedCredentials:
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
@pytest.mark.asyncio
async def test_provisions_when_missing(self):
"""Provider.provision() is called when no managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(return_value=cred)
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
store.add_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_when_already_exists(self):
"""Provider.provision() is NOT called when managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=True)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_unavailable(self):
"""Provider.provision() is NOT called when provider is not available."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=False)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
store.has_managed_credential.assert_not_awaited()
@pytest.mark.asyncio
async def test_provision_failure_does_not_propagate(self):
"""A failed provision is logged but does not raise."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
# No exception raised — provisioning failure is swallowed.
class TestCleanupManagedCredentials:
"""Unit tests for cleanup_managed_credentials."""
@pytest.mark.asyncio
async def test_calls_deprovision_for_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_non_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
regular = _make_api_key_cred()
provider = MagicMock()
provider.provider_name = "openai"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[regular])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["openai"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_not_awaited()
@pytest.mark.asyncio
async def test_deprovision_failure_does_not_propagate(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.

View File

@@ -0,0 +1,120 @@
"""Shared logic for adding store agents to a user's library.
Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
delegate to these helpers so the duplication-prone create/restore/dedup
logic lives in exactly one place.
"""
import logging
import prisma.errors
import prisma.models
import backend.api.features.library.model as library_model
import backend.data.graph as graph_db
from backend.data.graph import GraphModel, GraphSettings
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def resolve_graph_for_library(
store_listing_version_id: str,
user_id: str,
*,
admin: bool,
) -> GraphModel:
"""Look up a StoreListingVersion and resolve its graph.
When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
APPROVED-only check. Otherwise uses the regular ``get_graph``.
"""
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"AgentGraph": True}
)
if not slv or not slv.AgentGraph:
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
ag = slv.AgentGraph
if admin:
graph_model = await graph_db.get_graph_as_admin(
graph_id=ag.id, version=ag.version, user_id=user_id
)
else:
graph_model = await graph_db.get_graph(
graph_id=ag.id, version=ag.version, user_id=user_id
)
if not graph_model:
raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
return graph_model
async def add_graph_to_library(
store_listing_version_id: str,
graph_model: GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
"""Check existing / restore soft-deleted / create new LibraryAgent.
Uses a create-then-catch-UniqueViolationError-then-update pattern on
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
This is more robust than ``upsert`` because Prisma's upsert atomicity
guarantees are not well-documented for all versions.
"""
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
_include = library_agent_include(
user_id, include_nodes=False, include_executions=False
)
try:
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
"connect": {
"graphVersionId": {
"id": graph_model.id,
"version": graph_model.version,
}
}
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": settings_json,
},
include=_include,
)
except prisma.errors.UniqueViolationError:
# Already exists — update to restore if previously soft-deleted/archived
added_agent = await prisma.models.LibraryAgent.prisma().update(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_model.id,
"agentGraphVersion": graph_model.version,
}
},
data={
"isDeleted": False,
"isArchived": False,
"settings": settings_json,
},
include=_include,
)
if added_agent is None:
raise NotFoundError(
f"LibraryAgent for graph #{graph_model.id} "
f"v{graph_model.version} not found after UniqueViolationError"
)
logger.debug(
f"Added graph #{graph_model.id} v{graph_model.version} "
f"for store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)

View File

@@ -0,0 +1,80 @@
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.errors
import pytest
from ._add_to_library import add_graph_to_library
@pytest.mark.asyncio
async def test_add_graph_to_library_create_new_agent() -> None:
"""When no matching LibraryAgent exists, create inserts a new one."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
created_agent = MagicMock(name="CreatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
):
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent)
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
assert create_data["User"] == {"connect": {"id": "user-id"}}
assert create_data["AgentGraph"] == {
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
}
assert create_data["isCreatedByUser"] is False
assert create_data["useGraphIsActiveVersion"] is False
@pytest.mark.asyncio
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"""UniqueViolationError on create falls back to update."""
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
updated_agent = MagicMock(name="UpdatedLibraryAgent")
converted_agent = MagicMock(name="ConvertedLibraryAgent")
with (
patch(
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
) as mock_prisma,
patch(
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
):
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
MagicMock(), message="unique constraint"
)
)
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent)
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {
"userId_agentGraphId_agentGraphVersion": {
"userId": "user-id",
"agentGraphId": "graph-id",
"agentGraphVersion": 2,
}
}
update_data = update_call.kwargs["data"]
assert update_data["isDeleted"] is False
assert update_data["isArchived"] is False

View File

@@ -336,12 +336,15 @@ async def get_library_agent_by_graph_id(
user_id: str,
graph_id: str,
graph_version: Optional[int] = None,
include_archived: bool = False,
) -> library_model.LibraryAgent | None:
filter: prisma.types.LibraryAgentWhereInput = {
"agentGraphId": graph_id,
"userId": user_id,
"isDeleted": False,
}
if not include_archived:
filter["isArchived"] = False
if graph_version is not None:
filter["agentGraphVersion"] = graph_version
@@ -433,32 +436,53 @@ async def create_library_agent(
async with transaction() as tx:
library_agents = await asyncio.gather(
*(
prisma.models.LibraryAgent.prisma(tx).create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
prisma.models.LibraryAgent.prisma(tx).upsert(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph_entry.id,
"agentGraphVersion": graph_entry.version,
}
},
data={
"create": prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
"id": graph_entry.id,
"version": graph_entry.version,
}
}
}
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
"update": {
"isDeleted": False,
"isArchived": False,
"useGraphIsActiveVersion": True,
"settings": SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
},
settings=SafeJson(
GraphSettings.from_graph(
graph_entry,
hitl_safe_mode=hitl_safe_mode,
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
},
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),
@@ -582,7 +606,9 @@ async def update_graph_in_library(
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
library_agent = await get_library_agent_by_graph_id(
user_id, created_graph.id, include_archived=True
)
if not library_agent:
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
@@ -818,92 +844,38 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
async def add_store_agent_to_library(
store_listing_version_id: str, user_id: str
) -> library_model.LibraryAgent:
"""Adds a marketplace agent to the users library.
See also: `add_store_agent_to_library_as_admin()` which uses
`get_graph_as_admin` to bypass marketplace status checks for admin review.
"""
Adds an agent from a store listing version to the user's library if they don't already have it.
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
Args:
store_listing_version_id: The ID of the store listing version containing the agent.
user_id: The users library to which the agent is being added.
Returns:
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
NotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
f"Adding agent from store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"AgentGraph": True}
)
graph_model = await resolve_graph_for_library(
store_listing_version_id, user_id, admin=False
)
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
graph = store_listing_version.AgentGraph
# Convert to GraphModel to check for HITL blocks
graph_model = await graph_db.get_graph(
graph_id=graph.id,
version=graph.version,
user_id=user_id,
include_subgraphs=False,
async def add_store_agent_to_library_as_admin(
store_listing_version_id: str, user_id: str
) -> library_model.LibraryAgent:
"""Admin variant that uses `get_graph_as_admin` to bypass marketplace
APPROVED-only checks, allowing admins to add pending agents for review."""
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
logger.warning(
f"ADMIN adding agent from store listing version "
f"#{store_listing_version_id} to library for user #{user_id}"
)
if not graph_model:
raise NotFoundError(
f"Graph #{graph.id} v{graph.version} not found or accessible"
)
# Check if user already has this agent (non-deleted)
if existing := await get_library_agent_by_graph_id(
user_id, graph.id, graph.version
):
return existing
# Check for soft-deleted version and restore it
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
}
},
graph_model = await resolve_graph_for_library(
store_listing_version_id, user_id, admin=True
)
if deleted_agent and deleted_agent.isDeleted:
return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
},
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),
)
logger.debug(
f"Added graph #{graph.id} v{graph.version}"
f"for store listing version #{store_listing_version.id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
##############################################

View File

@@ -1,4 +1,6 @@
from contextlib import asynccontextmanager
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import prisma.enums
import prisma.models
@@ -85,10 +87,6 @@ async def test_get_library_agents(mocker):
async def test_add_agent_to_library(mocker):
await connect()
# Mock the transaction context
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
# Mock data
mock_store_listing_data = prisma.models.StoreListingVersion(
id="version123",
@@ -143,15 +141,18 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
# Mock graph_db.get_graph function that's called to check for HITL blocks
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
# (lives in _add_to_library.py after refactor, not db.py)
mock_graph_db = mocker.patch(
"backend.api.features.library._add_to_library.graph_db"
)
mock_graph_model = mocker.Mock()
mock_graph_model.id = "agent1"
mock_graph_model.version = 1
mock_graph_model.nodes = (
[]
) # Empty list so _has_human_in_the_loop_blocks returns False
@@ -170,37 +171,27 @@ async def test_add_agent_to_library(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_unique.assert_called_once_with(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
}
},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args
assert create_call_args is not None
# Verify the main structure
expected_data = {
# Verify the create data structure
create_data = create_call_args.kwargs["data"]
expected_create = {
"User": {"connect": {"id": "test-user"}},
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
}
actual_data = create_call_args[1]["data"]
# Check that all expected fields are present
for key, value in expected_data.items():
assert actual_data[key] == value
for key, value in expected_create.items():
assert create_data[key] == value
# Check that settings field is present and is a SafeJson object
assert "settings" in actual_data
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
assert "settings" in create_data
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
# Check include parameter
assert create_call_args[1]["include"] == library_agent_include(
assert create_call_args.kwargs["include"] == library_agent_include(
"test-user", include_nodes=False, include_executions=False
)
@@ -224,3 +215,141 @@ async def test_add_agent_to_library_not_found(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
@pytest.mark.asyncio
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)
assert result is None
mock_library_agent.return_value.find_first.assert_called_once()
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
assert where == {
"agentGraphId": "agent1",
"userId": "test-user",
"isDeleted": False,
"isArchived": False,
"agentGraphVersion": 7,
}
@pytest.mark.asyncio
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
result = await db.get_library_agent_by_graph_id(
"test-user",
"agent1",
7,
include_archived=True,
)
assert result is None
mock_library_agent.return_value.find_first.assert_called_once()
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
assert where == {
"agentGraphId": "agent1",
"userId": "test-user",
"isDeleted": False,
"agentGraphVersion": 7,
}
@pytest.mark.asyncio
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
graph = mocker.Mock(id="graph-id")
existing_version = mocker.Mock(version=1, is_active=True)
graph_model = mocker.Mock()
created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
current_library_agent = mocker.Mock()
updated_library_agent = mocker.Mock()
mocker.patch(
"backend.api.features.library.db.graph_db.get_graph_all_versions",
new=mocker.AsyncMock(return_value=[existing_version]),
)
mocker.patch(
"backend.api.features.library.db.graph_db.make_graph_model",
return_value=graph_model,
)
mocker.patch(
"backend.api.features.library.db.graph_db.create_graph",
new=mocker.AsyncMock(return_value=created_graph),
)
mock_get_library_agent = mocker.patch(
"backend.api.features.library.db.get_library_agent_by_graph_id",
new=mocker.AsyncMock(return_value=current_library_agent),
)
mock_update_library_agent = mocker.patch(
"backend.api.features.library.db.update_library_agent_version_and_settings",
new=mocker.AsyncMock(return_value=updated_library_agent),
)
result_graph, result_library_agent = await db.update_graph_in_library(
graph,
"test-user",
)
assert result_graph is created_graph
assert result_library_agent is updated_library_agent
assert graph.version == 2
graph_model.reassign_ids.assert_called_once_with(
user_id="test-user", reassign_graph_id=False
)
mock_get_library_agent.assert_awaited_once_with(
"test-user",
"graph-id",
include_archived=True,
)
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
@pytest.mark.asyncio
async def test_create_library_agent_uses_upsert():
"""create_library_agent should use upsert (not create) to handle duplicates."""
mock_graph = MagicMock()
mock_graph.id = "graph-1"
mock_graph.version = 1
mock_graph.user_id = "user-1"
mock_graph.nodes = []
mock_graph.sub_graphs = []
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
@asynccontextmanager
async def fake_tx():
yield None
with (
patch("backend.api.features.library.db.transaction", fake_tx),
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
patch(
"backend.api.features.library.db.add_generated_agent_image",
new=AsyncMock(),
),
patch(
"backend.api.features.library.model.LibraryAgent.from_db",
return_value=MagicMock(),
),
):
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
result = await db.create_library_agent(mock_graph, "user-1")
assert len(result) == 1
upsert_call = mock_prisma.return_value.upsert.call_args
assert upsert_call is not None
# Verify the upsert where clause uses the composite unique key
where = upsert_call.kwargs["where"]
assert "userId_agentGraphId_agentGraphVersion" in where
# Verify the upsert data has both create and update branches
data = upsert_call.kwargs["data"]
assert "create" in data
assert "update" in data
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False

View File

@@ -12,6 +12,7 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""
import asyncio
import base64
import hashlib
import secrets
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
yield test_user_id
# Cleanup - delete in correct order due to foreign key constraints
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
)
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
await PrismaUser.prisma().delete(where={"id": test_user_id})
# Cleanup - delete in correct order due to foreign key constraints.
# Wrap in try/except because the event loop or Prisma engine may already
# be closed during session teardown on Python 3.12+.
try:
await asyncio.gather(
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
PrismaOAuthRefreshToken.prisma().delete_many(
where={"userId": test_user_id}
),
PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
),
)
await asyncio.gather(
PrismaOAuthApplication.prisma().delete_many(
where={"ownerId": test_user_id}
),
PrismaUser.prisma().delete(where={"id": test_user_id}),
)
except RuntimeError:
pass
@pytest_asyncio.fixture

View File

@@ -391,6 +391,11 @@ async def get_available_graph(
async def get_store_agent_by_version_id(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
"""Get agent details from the StoreAgent view (APPROVED agents only).
See also: `get_store_agent_details_as_admin()` which bypasses the
APPROVED-only StoreAgent view for admin preview of pending submissions.
"""
logger.debug(f"Getting store agent details for {store_listing_version_id}")
try:
@@ -411,6 +416,57 @@ async def get_store_agent_by_version_id(
raise DatabaseError("Failed to fetch agent details") from e
async def get_store_agent_details_as_admin(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
"""Get agent details for admin preview, bypassing the APPROVED-only
StoreAgent view. Queries StoreListingVersion directly so pending
submissions are visible."""
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id},
include={
"StoreListing": {"include": {"CreatorProfile": True}},
},
)
if not slv or not slv.StoreListing:
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found"
)
listing = slv.StoreListing
# CreatorProfile is a required FK relation — should always exist.
# If it's None, the DB is in a bad state.
profile = listing.CreatorProfile
if not profile:
raise DatabaseError(
f"StoreListing {listing.id} has no CreatorProfile — FK violated"
)
return store_model.StoreAgentDetails(
store_listing_version_id=slv.id,
slug=listing.slug,
agent_name=slv.name,
agent_video=slv.videoUrl or "",
agent_output_demo=slv.agentOutputDemoUrl or "",
agent_image=slv.imageUrls,
creator=profile.username,
creator_avatar=profile.avatarUrl or "",
sub_heading=slv.subHeading,
description=slv.description,
instructions=slv.instructions,
categories=slv.categories,
runs=0,
rating=0.0,
versions=[str(slv.version)],
graph_id=slv.agentGraphId,
graph_versions=[str(slv.agentGraphVersion)],
last_updated=slv.updatedAt,
recommended_schedule_cron=slv.recommendedScheduleCron,
active_version_id=listing.activeVersionId or slv.id,
has_approved_version=listing.hasApprovedVersion,
)
class StoreCreatorsSortOptions(Enum):
# NOTE: values correspond 1:1 to columns of the Creator view
AGENT_RATING = "agent_rating"

View File

@@ -980,14 +980,16 @@ async def execute_graph(
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
dry_run: Annotated[bool, Body(embed=True)] = False,
) -> execution_db.GraphExecutionMeta:
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,
detail="Insufficient balance to execute the agent. Please top up your account.",
)
if not dry_run:
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,
detail="Insufficient balance to execute the agent. Please top up your account.",
)
try:
result = await execution_utils.add_graph_execution(
@@ -997,6 +999,7 @@ async def execute_graph(
preset_id=preset_id,
graph_version=graph_version,
graph_credentials_inputs=credentials_inputs,
dry_run=dry_run,
)
# Record successful graph execution
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)

View File

@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.rate_limit_admin_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -117,6 +118,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Register managed credential providers (e.g. AgentMail)
from backend.integrations.managed_providers import register_all
register_all()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -318,6 +324,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.rate_limit_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/copilot",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],
@@ -528,8 +539,11 @@ class AgentServer(backend.util.service.AppProcess):
user_id: str,
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
from .features.integrations.router import create_credentials, get_credential
):
from backend.api.features.integrations.router import (
create_credentials,
get_credential,
)
try:
return await create_credentials(

View File

@@ -1,3 +1,4 @@
import re
from typing import Any
from backend.blocks._base import (
@@ -19,6 +20,33 @@ from backend.blocks.llm import (
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
# A true/false answer fits comfortably within this budget.
MIN_LLM_OUTPUT_TOKENS = 16
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
"""Parse an LLM response into a boolean result.
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
response is unambiguous; otherwise it contains a diagnostic message
and *result* defaults to ``False``.
"""
text = response_text.strip().lower()
if text == "true":
return True, None
if text == "false":
return False, None
# Fuzzy match use word boundaries to avoid false positives like "untrue".
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
return True, None
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
return False, None
return False, f"Unclear AI response: '{response_text}'"
class AIConditionBlock(AIBlockBase):
"""
@@ -162,54 +190,26 @@ class AIConditionBlock(AIBlockBase):
]
# Call the LLM
try:
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=10, # We only expect a true/false response
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=MIN_LLM_OUTPUT_TOKENS,
)
# Extract the boolean result from the response
result, error = _parse_boolean_response(response.response)
if error:
yield "error", error
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
)
# Extract the boolean result from the response
response_text = response.response.strip().lower()
if response_text == "true":
result = True
elif response_text == "false":
result = False
else:
# If the response is not clear, try to interpret it using word boundaries
import re
# Use word boundaries to avoid false positives like 'untrue' or '10'
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
result = True
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
result = False
else:
# Unclear or conflicting response - default to False and yield error
result = False
yield "error", f"Unclear AI response: '{response.response}'"
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
)
)
self.prompt = response.prompt
except Exception as e:
# In case of any error, default to False to be safe
result = False
# Log the error but don't fail the block execution
import logging
logger = logging.getLogger(__name__)
logger.error(f"AI condition evaluation failed: {str(e)}")
yield "error", f"AI evaluation failed: {str(e)}"
)
self.prompt = response.prompt
# Yield results
yield "result", result

View File

@@ -0,0 +1,147 @@
"""Tests for AIConditionBlock regression coverage for max_tokens and error propagation."""
from __future__ import annotations
from typing import cast
import pytest
from backend.blocks.ai_condition import (
MIN_LLM_OUTPUT_TOKENS,
AIConditionBlock,
_parse_boolean_response,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AICredentials,
LLMResponse,
)
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
# ---------------------------------------------------------------------------
# Helper to collect all yields from the async generator
# ---------------------------------------------------------------------------
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
outputs: dict[str, object] = {}
async for name, value in block.run(input_data, credentials=credentials):
outputs[name] = value
return outputs
def _make_input(**overrides) -> AIConditionBlock.Input:
defaults: dict = {
"input_value": "hello@example.com",
"condition": "the input is an email address",
"yes_value": "yes!",
"no_value": "no!",
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return AIConditionBlock.Input(**defaults)
def _mock_llm_response(response_text: str) -> LLMResponse:
return LLMResponse(
raw_response="",
prompt=[],
response=response_text,
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
)
# ---------------------------------------------------------------------------
# _parse_boolean_response unit tests
# ---------------------------------------------------------------------------
class TestParseBooleanResponse:
def test_true_exact(self):
assert _parse_boolean_response("true") == (True, None)
def test_false_exact(self):
assert _parse_boolean_response("false") == (False, None)
def test_true_with_whitespace(self):
assert _parse_boolean_response(" True ") == (True, None)
def test_yes_fuzzy(self):
assert _parse_boolean_response("Yes") == (True, None)
def test_no_fuzzy(self):
assert _parse_boolean_response("no") == (False, None)
def test_one_fuzzy(self):
assert _parse_boolean_response("1") == (True, None)
def test_zero_fuzzy(self):
assert _parse_boolean_response("0") == (False, None)
def test_unclear_response(self):
result, error = _parse_boolean_response("I'm not sure")
assert result is False
assert error is not None
assert "Unclear" in error
def test_conflicting_tokens(self):
result, error = _parse_boolean_response("true and false")
assert result is False
assert error is not None
# ---------------------------------------------------------------------------
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
# ---------------------------------------------------------------------------
class TestMaxTokensRegression:
@pytest.mark.asyncio
async def test_llm_call_receives_min_output_tokens(self):
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) the previous value
of 1 was too low and caused OpenAI to reject the request."""
block = AIConditionBlock()
captured_kwargs: dict = {}
async def spy_llm_call(**kwargs):
captured_kwargs.update(kwargs)
return _mock_llm_response("true")
block.llm_call = spy_llm_call # type: ignore[assignment]
input_data = _make_input()
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
assert captured_kwargs["max_tokens"] == 16
# ---------------------------------------------------------------------------
# Regression: exceptions from llm_call must propagate
# ---------------------------------------------------------------------------
class TestExceptionPropagation:
@pytest.mark.asyncio
async def test_llm_call_exception_propagates(self):
"""If llm_call raises, the exception must NOT be swallowed.
Previously the block caught all exceptions and silently returned
result=False."""
block = AIConditionBlock()
async def boom(**kwargs):
raise RuntimeError("LLM provider error")
block.llm_call = boom # type: ignore[assignment]
input_data = _make_input()
with pytest.raises(RuntimeError, match="LLM provider error"):
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)

View File

@@ -73,7 +73,7 @@ class ReadDiscordMessagesBlock(Block):
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
description="Reads messages from a Discord channel using a bot token.",
description="Reads new messages from a Discord channel using a bot token and triggers when a new message is posted",
categories={BlockCategory.SOCIAL},
test_input={
"continuous_read": False,

View File

@@ -1,5 +1,6 @@
import asyncio
import base64
import re
from abc import ABC
from email import encoders
from email.mime.base import MIMEBase
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
from email.policy import SMTP
from email.utils import getaddresses, parseaddr
from pathlib import Path
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Protocol, runtime_checkable
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
def serialize_email_recipients(recipients: list[str]) -> str:
"""Serialize recipients list to comma-separated string."""
return ", ".join(recipients)
"""Serialize recipients list to comma-separated string.
Strips leading/trailing whitespace from each address to keep MIME
headers clean (mirrors the strip done in ``validate_email_recipients``).
"""
return ", ".join(addr.strip() for addr in recipients)
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
"""Validate that all recipients are plausible email addresses.
Raises ``ValueError`` with a user-friendly message listing every
invalid entry so the caller (or LLM) can correct them in one pass.
"""
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
if invalid:
formatted = ", ".join(f"'{a}'" for a in invalid)
raise ValueError(
f"Invalid email address(es) in '{field_name}': {formatted}. "
f"Each entry must be a valid email address (e.g. user@example.com)."
)
@runtime_checkable
class HasRecipients(Protocol):
to: list[str]
cc: list[str]
bcc: list[str]
def validate_all_recipients(input_data: HasRecipients) -> None:
"""Validate to/cc/bcc recipient fields on an input namespace.
Calls ``validate_email_recipients`` for ``to`` (required) and
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
first field that contains an invalid address.
"""
validate_email_recipients(input_data.to, "to")
if input_data.cc:
validate_email_recipients(input_data.cc, "cc")
if input_data.bcc:
validate_email_recipients(input_data.bcc, "bcc")
def _make_mime_text(
@@ -100,14 +145,16 @@ async def create_mime_message(
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
validate_all_recipients(input_data)
message = MIMEMultipart()
message["to"] = serialize_email_recipients(input_data.to)
message["subject"] = input_data.subject
if input_data.cc:
message["cc"] = ", ".join(input_data.cc)
message["cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
message["bcc"] = ", ".join(input_data.bcc)
message["bcc"] = serialize_email_recipients(input_data.bcc)
# Use the new helper function with content_type if available
content_type = getattr(input_data, "content_type", None)
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
references.append(headers["message-id"])
# Create MIME message
validate_all_recipients(input_data)
msg = MIMEMultipart()
if input_data.to:
msg["To"] = ", ".join(input_data.to)
msg["To"] = serialize_email_recipients(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
msg["Cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Subject"] = subject
if headers.get("message-id"):
msg["In-Reply-To"] = headers["message-id"]
@@ -1685,13 +1734,16 @@ To: {original_to}
else:
body = f"{forward_header}\n\n{original_body}"
# Validate all recipient lists before building the MIME message
validate_all_recipients(input_data)
# Create MIME message
msg = MIMEMultipart()
msg["To"] = ", ".join(input_data.to)
msg["To"] = serialize_email_recipients(input_data.to)
if input_data.cc:
msg["Cc"] = ", ".join(input_data.cc)
msg["Cc"] = serialize_email_recipients(input_data.cc)
if input_data.bcc:
msg["Bcc"] = ", ".join(input_data.bcc)
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
msg["Subject"] = subject
# Add body with proper content type

View File

@@ -28,9 +28,9 @@ class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, default values list and bool to limit selection to default values.
It takes in a value, name, and description.
It Outputs the value passed as input.
It outputs the value passed as input.
"""
class Input(BlockSchemaInput):
@@ -47,12 +47,6 @@ class AgentInputBlock(Block):
default=None,
advanced=True,
)
placeholder_values: list = SchemaField(
description="The placeholder values to be passed as input.",
default_factory=list,
advanced=True,
hidden=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
@@ -65,10 +59,7 @@ class AgentInputBlock(Block):
)
def generate_schema(self):
schema = copy.deepcopy(self.get_field_schema("value"))
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
return copy.deepcopy(self.get_field_schema("value"))
class Output(BlockSchema):
# Use BlockSchema to avoid automatic error field for interface definition
@@ -86,18 +77,16 @@ class AgentInputBlock(Block):
"value": "Hello, World!",
"name": "input_1",
"description": "Example test input.",
"placeholder_values": [],
},
{
"value": "Hello, World!",
"value": 42,
"name": "input_2",
"description": "Example test input with placeholders.",
"placeholder_values": ["Hello, World!"],
"description": "Example numeric input.",
},
],
"test_output": [
("result", "Hello, World!"),
("result", "Hello, World!"),
("result", 42),
],
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
"block_type": BlockType.INPUT,
@@ -245,13 +234,11 @@ class AgentShortTextInputBlock(AgentInputBlock):
"value": "Hello",
"name": "short_text_1",
"description": "Short text example 1",
"placeholder_values": [],
},
{
"value": "Quick test",
"name": "short_text_2",
"description": "Short text example 2",
"placeholder_values": ["Quick test", "Another option"],
},
],
test_output=[
@@ -285,13 +272,11 @@ class AgentLongTextInputBlock(AgentInputBlock):
"value": "Lorem ipsum dolor sit amet...",
"name": "long_text_1",
"description": "Long text example 1",
"placeholder_values": [],
},
{
"value": "Another multiline text input.",
"name": "long_text_2",
"description": "Long text example 2",
"placeholder_values": ["Another multiline text input."],
},
],
test_output=[
@@ -325,13 +310,11 @@ class AgentNumberInputBlock(AgentInputBlock):
"value": 42,
"name": "number_input_1",
"description": "Number example 1",
"placeholder_values": [],
},
{
"value": 314,
"name": "number_input_2",
"description": "Number example 2",
"placeholder_values": [314, 2718],
},
],
test_output=[
@@ -501,6 +484,12 @@ class AgentDropdownInputBlock(AgentInputBlock):
title="Dropdown Options",
)
def generate_schema(self):
schema = super().generate_schema()
if possible_values := self.placeholder_values:
schema["enum"] = possible_values
return schema
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="Selected dropdown value.")

View File

@@ -104,6 +104,18 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
if isinstance(value, str) and "/" in value:
stripped = value.split("/", 1)[1]
try:
return cls(stripped)
except ValueError:
return None
return None
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
@@ -712,6 +724,9 @@ def convert_openai_tool_fmt_to_anthropic(
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_reasoning")
return None
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
@@ -727,6 +742,9 @@ def extract_openai_reasoning(response) -> str | None:
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if not response.choices:
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
return None
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
@@ -960,6 +978,8 @@ async def llm_call(
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("Groq returned empty choices in response")
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
@@ -1019,12 +1039,8 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"OpenRouter error: {response}")
else:
raise ValueError("No response from OpenRouter.")
raise ValueError(f"OpenRouter returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1061,12 +1077,8 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
if not response.choices:
if response:
raise ValueError(f"Llama API error: {response}")
else:
raise ValueError("No response from Llama API.")
raise ValueError(f"Llama API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1096,6 +1108,8 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
if not completion.choices:
raise ValueError("AI/ML API returned empty choices in response")
return LLMResponse(
raw_response=completion.choices[0].message,
@@ -1132,6 +1146,9 @@ async def llm_call(
parallel_tool_calls=parallel_tool_calls_param,
)
if not response.choices:
raise ValueError(f"v0 API returned empty choices: {response}")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
@@ -1999,6 +2016,19 @@ class AIConversationBlock(AIBlockBase):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
has_messages = any(
isinstance(m, dict)
and isinstance(m.get("content"), str)
and bool(m["content"].strip())
for m in (input_data.messages or [])
)
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
if not has_messages and not has_prompt:
raise ValueError(
"Cannot call LLM with no messages and no prompt. "
"Provide at least one message or a non-empty prompt."
)
response = await self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,8 @@ import pytest
from backend.blocks import get_blocks
from backend.blocks._base import Block, BlockSchemaInput
from backend.blocks.io import AgentDropdownInputBlock, AgentInputBlock
from backend.data.graph import BaseGraph
from backend.data.model import SchemaField
from backend.util.test import execute_block_test
@@ -279,3 +281,66 @@ class TestAutoCredentialsFieldsValidation:
assert "Duplicate auto_credentials kwarg_name 'credentials'" in str(
exc_info.value
)
def test_agent_input_block_ignores_legacy_placeholder_values():
"""Verify AgentInputBlock.Input.model_construct tolerates extra placeholder_values
for backward compatibility with existing agent JSON."""
legacy_data = {
"name": "url",
"value": "",
"description": "Enter a URL",
"placeholder_values": ["https://example.com"],
}
instance = AgentInputBlock.Input.model_construct(**legacy_data)
schema = instance.generate_schema()
assert (
"enum" not in schema
), "AgentInputBlock should not produce enum from legacy placeholder_values"
def test_dropdown_input_block_produces_enum():
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
options = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=options
)
schema = instance.generate_schema()
assert schema.get("enum") == options
def test_generate_schema_integration_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with legacy placeholder_values
on AgentInputBlock — verifies no enum leaks through the graph loading path."""
legacy_input_default = {
"name": "url",
"value": "",
"description": "Enter a URL",
"placeholder_values": ["https://example.com"],
}
result = BaseGraph._generate_schema(
(AgentInputBlock.Input, legacy_input_default),
)
url_props = result["properties"]["url"]
assert (
"enum" not in url_props
), "Graph schema should not contain enum from AgentInputBlock placeholder_values"
def test_generate_schema_integration_dropdown_produces_enum():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
— verifies enum IS produced for dropdown blocks."""
dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Graph schema should contain enum from AgentDropdownInputBlock"

View File

@@ -207,6 +207,51 @@ class TestXMLParserBlockSecurity:
pass
class TestXMLParserBlockSyntaxErrors:
"""XML syntax errors should raise ValueError (not SyntaxError).
This ensures the base Block.execute() wraps them as BlockExecutionError
(expected / user-caused) instead of BlockUnknownError (unexpected / alerts
Sentry).
"""
async def test_unclosed_tag_raises_value_error(self):
"""Unclosed tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "<root><unclosed>"
with pytest.raises(ValueError, match="Unclosed tag"):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_unexpected_closing_tag_raises_value_error(self):
"""Extra closing tags should raise ValueError, not SyntaxError."""
block = XMLParserBlock()
bad_xml = "</unexpected>"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
async def test_empty_xml_raises_value_error(self):
"""Empty XML input should raise ValueError."""
block = XMLParserBlock()
with pytest.raises(ValueError, match="XML input is empty"):
async for _ in block.run(XMLParserBlock.Input(input_xml="")):
pass
async def test_syntax_error_from_parser_becomes_value_error(self):
"""SyntaxErrors from gravitasml library become ValueError (BlockExecutionError)."""
block = XMLParserBlock()
# Malformed XML that might trigger a SyntaxError from the parser
bad_xml = "<root><child>no closing"
with pytest.raises(ValueError):
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -488,6 +488,154 @@ class TestLLMStatsTracking:
assert outputs["response"] == {"result": "test"}
class TestAIConversationBlockValidation:
"""Test that AIConversationBlock validates inputs before calling the LLM."""
@pytest.mark.asyncio
async def test_empty_messages_and_empty_prompt_raises_error(self):
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_empty_messages_with_prompt_succeeds(self):
"""Empty messages but a non-empty prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "OK"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="Hello, how are you?",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "OK"
@pytest.mark.asyncio
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
"""Non-empty messages with no prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "response from conversation"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": "Hello"}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "response from conversation"
@pytest.mark.asyncio
async def test_messages_with_empty_content_raises_error(self):
"""Messages with empty content strings should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": ""}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_whitespace_content_raises_error(self):
"""Messages with whitespace-only content should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": " "}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_entry_raises_error(self):
"""Messages list containing None should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[None],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_empty_dict_raises_error(self):
"""Messages list containing empty dict should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_content_raises_error(self):
"""Messages with content=None should not crash with AttributeError."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": None}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
class TestAITextSummarizerValidation:
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
@@ -809,3 +957,33 @@ class TestUserErrorStatusCodeHandling:
mock_warning.assert_called_once()
mock_exception.assert_not_called()
class TestLlmModelMissing:
"""Test that LlmModel handles provider-prefixed model names."""
def test_provider_prefixed_model_resolves(self):
"""Provider-prefixed model string should resolve to the correct enum member."""
assert (
llm.LlmModel("anthropic/claude-sonnet-4-6")
== llm.LlmModel.CLAUDE_4_6_SONNET
)
def test_bare_model_still_works(self):
"""Bare (non-prefixed) model string should still resolve correctly."""
assert llm.LlmModel("claude-sonnet-4-6") == llm.LlmModel.CLAUDE_4_6_SONNET
def test_invalid_prefixed_model_raises(self):
"""Unknown provider-prefixed model string should raise ValueError."""
with pytest.raises(ValueError):
llm.LlmModel("invalid/nonexistent-model")
def test_slash_containing_value_direct_lookup(self):
"""Enum values with '/' (e.g., OpenRouter models) should resolve via direct lookup, not _missing_."""
assert llm.LlmModel("google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
def test_double_prefixed_slash_model(self):
"""Double-prefixed value should still resolve by stripping first prefix."""
assert (
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
)

View File

@@ -0,0 +1,87 @@
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
from unittest.mock import MagicMock
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
class TestExtractOpenaiToolCallsEmptyChoices:
"""extract_openai_tool_calls() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_tool_calls(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_tool_calls(response) is None
def test_returns_tool_calls_when_choices_present(self):
tool = MagicMock()
tool.id = "call_1"
tool.type = "function"
tool.function.name = "my_func"
tool.function.arguments = '{"a": 1}'
message = MagicMock()
message.tool_calls = [tool]
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
result = extract_openai_tool_calls(response)
assert result is not None
assert len(result) == 1
assert result[0].function.name == "my_func"
def test_returns_none_when_no_tool_calls(self):
message = MagicMock()
message.tool_calls = None
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
assert extract_openai_tool_calls(response) is None
class TestExtractOpenaiReasoningEmptyChoices:
"""extract_openai_reasoning() must return None when choices is empty."""
def test_returns_none_for_empty_choices(self):
response = MagicMock()
response.choices = []
assert extract_openai_reasoning(response) is None
def test_returns_none_for_none_choices(self):
response = MagicMock()
response.choices = None
assert extract_openai_reasoning(response) is None
def test_returns_reasoning_from_choice(self):
choice = MagicMock()
choice.reasoning = "Step-by-step reasoning"
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
response = MagicMock(spec=[]) # no 'reasoning' attr on response
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result == "Step-by-step reasoning"
def test_returns_none_when_no_reasoning(self):
choice = MagicMock(spec=[]) # no 'reasoning' attr
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
response = MagicMock(spec=[]) # no 'reasoning' attr
response.choices = [choice]
result = extract_openai_reasoning(response)
assert result is None

View File

@@ -1074,6 +1074,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)
@@ -1105,6 +1106,7 @@ async def test_orchestrator_falls_back_to_block_name():
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
mock_node.input_default = {}
# Create a mock link
mock_link = MagicMock(spec=Link)

View File

@@ -0,0 +1,202 @@
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
Covers:
- ExecutionMode enum members exist and have stable values
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
- EXTENDED_THINKING model-name validation (must start with "claude")
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LlmModel
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
# ---------------------------------------------------------------------------
# ExecutionMode enum integrity
# ---------------------------------------------------------------------------
class TestExecutionModeEnum:
"""Guard against accidental renames or removals of enum members."""
def test_built_in_exists(self):
assert hasattr(ExecutionMode, "BUILT_IN")
assert ExecutionMode.BUILT_IN.value == "built_in"
def test_extended_thinking_exists(self):
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
def test_exactly_two_members(self):
"""If a new mode is added, this test should be updated intentionally."""
assert set(ExecutionMode.__members__.keys()) == {
"BUILT_IN",
"EXTENDED_THINKING",
}
def test_string_enum(self):
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
assert isinstance(ExecutionMode.BUILT_IN, str)
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
def test_round_trip_from_value(self):
"""Constructing from the string value should return the same member."""
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
# ---------------------------------------------------------------------------
# Provider validation (inline in OrchestratorBlock.run)
# ---------------------------------------------------------------------------
def _make_model_stub(provider: str, value: str):
"""Create a lightweight stub that behaves like LlmModel for validation."""
metadata = MagicMock()
metadata.provider = provider
stub = MagicMock()
stub.metadata = metadata
stub.value = value
return stub
class TestExtendedThinkingProviderValidation:
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
def test_anthropic_provider_accepted(self):
"""provider='anthropic' + claude model should not raise."""
model = _make_model_stub("anthropic", "claude-opus-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_open_router_provider_accepted(self):
"""provider='open_router' + claude model should not raise."""
model = _make_model_stub("open_router", "claude-sonnet-4-6")
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
def test_openai_provider_rejected(self):
"""provider='openai' should be rejected for EXTENDED_THINKING."""
model = _make_model_stub("openai", "gpt-4o")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_groq_provider_rejected(self):
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
model = _make_model_stub("anthropic", "not-a-claude-model")
model_name = model.value
assert not model_name.startswith("claude")
def test_real_gpt4o_model_rejected(self):
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
model = LlmModel.GPT4O
provider = model.metadata.provider
assert provider not in ("anthropic", "open_router")
def test_real_claude_model_passes(self):
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
model = LlmModel.CLAUDE_4_6_SONNET
provider = model.metadata.provider
model_name = model.value
assert provider in ("anthropic", "open_router")
assert model_name.startswith("claude")
# ---------------------------------------------------------------------------
# Integration-style: exercise the validation branch via OrchestratorBlock.run
# ---------------------------------------------------------------------------
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
inp = MagicMock()
inp.execution_mode = execution_mode
inp.model = model
inp.prompt = "test"
inp.sys_prompt = ""
inp.conversation_history = []
inp.last_tool_output = None
inp.prompt_values = {}
return inp
async def _collect_run_outputs(block, input_data, **kwargs):
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
outputs = []
async for item in block.run(input_data, **kwargs):
outputs.append(item)
return outputs
class TestExtendedThinkingValidationRaisesInBlock:
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
@pytest.mark.asyncio
async def test_non_anthropic_provider_raises_valueerror(self):
"""EXTENDED_THINKING + openai provider raises ValueError."""
block = OrchestratorBlock()
input_data = _make_input_data(model=LlmModel.GPT4O)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="Anthropic-compatible"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)
@pytest.mark.asyncio
async def test_non_claude_model_with_anthropic_provider_raises(self):
"""A model with anthropic provider but non-claude name raises ValueError."""
block = OrchestratorBlock()
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
input_data = _make_input_data(model=fake_model)
with (
patch.object(
block,
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
),
pytest.raises(ValueError, match="only supports Claude models"),
):
await _collect_run_outputs(
block,
input_data,
credentials=MagicMock(),
graph_id="g",
node_id="n",
graph_exec_id="ge",
node_exec_id="ne",
user_id="u",
graph_version=1,
execution_context=MagicMock(),
execution_processor=MagicMock(),
)

File diff suppressed because it is too large Load Diff

View File

@@ -44,7 +44,7 @@ class XMLParserBlock(Block):
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
raise ValueError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
@@ -53,7 +53,7 @@ class XMLParserBlock(Block):
)
if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
raise ValueError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
@@ -76,4 +76,7 @@ class XMLParserBlock(Block):
except ValueError as val_e:
raise ValueError(f"Validation error for dict:{val_e}") from val_e
except SyntaxError as syn_e:
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e
# Raise as ValueError so the base Block.execute() wraps it as
# BlockExecutionError (expected user-caused failure) instead of
# BlockUnknownError (unexpected platform error that alerts Sentry).
raise ValueError(f"Error in input xml syntax: {syn_e}") from syn_e

View File

@@ -9,11 +9,14 @@ shared tool registry as the SDK path.
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from backend.copilot.model import (
ChatMessage,
@@ -48,7 +51,17 @@ from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.tool_call_loop import (
LLMLoopResponse,
LLMToolCall,
ToolCallResult,
tool_call_loop,
)
logger = logging.getLogger(__name__)
@@ -59,6 +72,247 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
@dataclass
class _BaselineStreamState:
"""Mutable state shared between the tool-call loop callbacks.
Extracted from ``stream_chat_completion_baseline`` so that the callbacks
can be module-level functions instead of deeply nested closures.
"""
pending_events: list[StreamBaseResponse] = field(default_factory=list)
assistant_text: str = ""
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
text_started: bool = False
turn_prompt_tokens: int = 0
turn_completion_tokens: int = 0
async def _baseline_llm_caller(
messages: list[dict[str, Any]],
tools: Sequence[Any],
*,
state: _BaselineStreamState,
) -> LLMLoopResponse:
"""Stream an OpenAI-compatible response and collect results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
state.pending_events.append(StreamStartStep())
round_text = ""
try:
client = _get_openai_client()
typed_messages = cast(list[ChatCompletionMessageParam], messages)
if tools:
typed_tools = cast(list[ChatCompletionToolParam], tools)
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
tools=typed_tools,
stream=True,
stream_options={"include_usage": True},
)
else:
response = await client.chat.completions.create(
model=config.model,
messages=typed_messages,
stream=True,
stream_options={"include_usage": True},
)
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
if chunk.usage:
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
if delta.content:
if not state.text_started:
state.pending_events.append(StreamTextStart(id=state.text_block_id))
state.text_started = True
round_text += delta.content
state.pending_events.append(
StreamTextDelta(id=state.text_block_id, delta=delta.content)
)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block
if state.text_started:
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
state.text_started = False
state.text_block_id = str(uuid.uuid4())
finally:
# Always persist partial text so the session history stays consistent,
# even when the stream is interrupted by an exception.
state.assistant_text += round_text
# Always emit StreamFinishStep to match the StreamStartStep,
# even if an exception occurred during streaming.
state.pending_events.append(StreamFinishStep())
# Convert to shared format
llm_tool_calls = [
LLMToolCall(
id=tc["id"],
name=tc["name"],
arguments=tc["arguments"] or "{}",
)
for tc in tool_calls_by_index.values()
]
return LLMLoopResponse(
response_text=round_text or None,
tool_calls=llm_tool_calls,
raw_response=None, # Not needed for baseline conversation updater
prompt_tokens=0, # Tracked via state accumulators
completion_tokens=0,
)
async def _baseline_tool_executor(
tool_call: LLMToolCall,
tools: Sequence[Any],
*,
state: _BaselineStreamState,
user_id: str | None,
session: ChatSession,
) -> ToolCallResult:
"""Execute a tool via the copilot tool registry.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
tool_call_id = tool_call.id
tool_name = tool_call.name
raw_args = tool_call.arguments or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
logger.warning("[Baseline] %s", parse_error)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=parse_error,
is_error=True,
)
state.pending_events.append(
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
)
state.pending_events.append(
StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
)
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
state.pending_events.append(result)
tool_output = (
result.output if isinstance(result.output, str) else str(result.output)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=tool_output,
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
state.pending_events.append(
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
)
return ToolCallResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
content=error_output,
is_error=True,
)
def _baseline_conversation_updater(
messages: list[dict[str, Any]],
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None = None,
) -> None:
"""Update OpenAI message list with assistant response + tool results.
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
if tool_results:
# Build assistant message with tool_calls
assistant_msg: dict[str, Any] = {"role": "assistant"}
if response.response_text:
assistant_msg["content"] = response.response_text
assistant_msg["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {"name": tc.name, "arguments": tc.arguments},
}
for tc in response.tool_calls
]
messages.append(assistant_msg)
for tr in tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_call_id,
"content": tr.content,
}
)
else:
if response.response_text:
messages.append({"role": "assistant", "content": response.response_text})
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
@@ -219,191 +473,32 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
state = _BaselineStreamState()
# Bind extracted module-level callbacks to this request's state/session
# using functools.partial so they satisfy the Protocol signatures.
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
_bound_tool_executor = partial(
_baseline_tool_executor, state=state, user_id=user_id, session=session
)
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
loop_result = None
async for loop_result in tool_call_loop(
messages=openai_messages,
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_baseline_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
# Drain buffered events after each iteration (real-time streaming)
for evt in state.pending_events:
yield evt
state.pending_events.clear()
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
if loop_result and not loop_result.finished_naturally:
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
@@ -418,11 +513,28 @@ async def stream_chat_completion_baseline(
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
# Close any open text block. The llm_caller's finally block
# already appended StreamFinishStep to pending_events, so we must
# insert StreamTextEnd *before* StreamFinishStep to preserve the
# protocol ordering:
# StreamStartStep -> StreamTextStart -> ...deltas... ->
# StreamTextEnd -> StreamFinishStep
# Appending (or yielding directly) would place it after
# StreamFinishStep, violating the protocol.
if state.text_started:
# Find the last StreamFinishStep and insert before it.
insert_pos = len(state.pending_events)
for i in range(len(state.pending_events) - 1, -1, -1):
if isinstance(state.pending_events[i], StreamFinishStep):
insert_pos = i
break
state.pending_events.insert(
insert_pos, StreamTextEnd(id=state.text_block_id)
)
# Drain pending events in correct order
for evt in state.pending_events:
yield evt
state.pending_events.clear()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
@@ -442,26 +554,21 @@ async def stream_chat_completion_baseline(
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
state.turn_prompt_tokens == 0
and state.turn_completion_tokens == 0
and not (_stream_error and not state.assistant_text)
):
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
state.turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
state.turn_completion_tokens = estimate_token_count_str(
state.assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
state.turn_prompt_tokens,
state.turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
@@ -471,15 +578,15 @@ async def stream_chat_completion_baseline(
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if assistant_text:
if state.assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=assistant_text)
ChatMessage(role="assistant", content=state.assistant_text)
)
try:
await upsert_chat_session(session)
@@ -491,11 +598,11 @@ async def stream_chat_completion_baseline(
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
prompt_tokens=state.turn_prompt_tokens,
completion_tokens=state.turn_completion_tokens,
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -91,6 +91,20 @@ class ChatConfig(BaseSettings):
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Cost (in credits / cents) to reset the daily rate limit using credits.
# When a user hits their daily limit, they can spend this amount to reset
# the daily counter and keep working. Set to 0 to disable the feature.
rate_limit_reset_cost: int = Field(
default=500,
ge=0,
description="Credit cost (in cents) for resetting the daily rate limit. 0 = disabled.",
)
max_daily_resets: int = Field(
default=5,
ge=0,
description="Maximum number of credit-based rate limit resets per user per day. 0 = unlimited.",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
@@ -164,7 +178,7 @@ class ChatConfig(BaseSettings):
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``_build_sdk_env``.
present — mirrors the fallback logic in ``build_sdk_env``.
"""
if not self.use_openrouter:
return False

View File

@@ -104,17 +104,32 @@ def get_sdk_cwd() -> str:
E2B_WORKDIR = "/home/user"
E2B_ALLOWED_DIRS: tuple[str, ...] = (E2B_WORKDIR, "/tmp")
E2B_ALLOWED_DIRS_STR: str = " or ".join(E2B_ALLOWED_DIRS)
def is_within_allowed_dirs(path: str) -> bool:
"""Return True if *path* is within one of the allowed sandbox directories."""
for allowed in E2B_ALLOWED_DIRS:
if path == allowed or path.startswith(allowed + "/"):
return True
return False
def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
"""Normalise *path* to an absolute sandbox path under an allowed directory.
Allowed directories: ``/home/user`` and ``/tmp``.
Relative paths are resolved against ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
if not is_within_allowed_dirs(normalized):
raise ValueError(
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(path)}"
)
return normalized

View File

@@ -198,10 +198,32 @@ def test_resolve_sandbox_path_normalizes_dots():
def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="/home/user"):
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError, match="/home/user"):
with pytest.raises(ValueError):
resolve_sandbox_path("/etc/passwd")
def test_resolve_sandbox_path_tmp_allowed():
assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"
def test_resolve_sandbox_path_tmp_nested():
assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"
def test_resolve_sandbox_path_tmp_itself():
assert resolve_sandbox_path("/tmp") == "/tmp"
def test_resolve_sandbox_path_tmp_escape_raises():
with pytest.raises(ValueError):
resolve_sandbox_path("/tmp/../etc/passwd")
def test_resolve_sandbox_path_tmp_prefix_collision_raises():
with pytest.raises(ValueError):
resolve_sandbox_path("/tmp_evil/malicious.txt")

View File

@@ -18,7 +18,7 @@ from prisma.types import (
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
logger = logging.getLogger(__name__)
@@ -217,6 +217,9 @@ async def add_chat_messages_batch(
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
if msg.get("duration_ms") is not None:
data["durationMs"] = msg["duration_ms"]
messages_data.append(data)
# Run create_many and session update in parallel within transaction
@@ -359,3 +362,22 @@ async def update_tool_message_content(
f"tool_call_id {tool_call_id}: {e}"
)
return False
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
"""Set durationMs on the last assistant message in a session.
Also invalidates the Redis session cache so the next GET returns
the updated duration.
"""
last_msg = await PrismaChatMessage.prisma().find_first(
where={"sessionId": session_id, "role": "assistant"},
order={"sequence": "desc"},
)
if last_msg:
await PrismaChatMessage.prisma().update(
where={"id": last_msg.id},
data={"durationMs": duration_ms},
)
# Invalidate cache so the session is re-fetched from DB with durationMs
await invalidate_session_cache(session_id)

View File

@@ -14,7 +14,7 @@ import time
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamFinish
from backend.copilot.response_model import StreamError
from backend.copilot.sdk import service as sdk_service
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
from backend.executor.cluster_lock import ClusterLock
@@ -23,6 +23,7 @@ from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import set_service_name
from backend.util.retry import func_retry
from backend.util.workspace_storage import shutdown_workspace_storage
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
@@ -153,8 +154,6 @@ class CoPilotProcessor:
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
runs on the same loop that created the session.
"""
from backend.util.workspace_storage import shutdown_workspace_storage
coro = shutdown_workspace_storage()
try:
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
@@ -268,35 +267,37 @@ class CoPilotProcessor:
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
# Stream chat completion and publish chunks to Redis.
async for chunk in stream_fn(
# stream_and_publish wraps the raw stream with registry
# publishing (shared with collect_copilot_response).
raw_stream = stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
is_user_message=entry.is_user_message,
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,
turn_id=entry.turn_id,
stream=raw_stream,
):
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break
# Capture StreamError so mark_session_completed receives
# the error message (stream_and_publish yields but does
# not publish StreamError — that's done by mark_session_completed).
if isinstance(chunk, StreamError):
error_msg = chunk.errorText
break
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
# Skip StreamFinish — mark_session_completed publishes it.
if isinstance(chunk, StreamFinish):
continue
try:
await stream_registry.publish_chunk(entry.turn_id, chunk)
except Exception as e:
log.error(
f"Error publishing chunk {type(chunk).__name__}: {e}",
exc_info=True,
)
# Stream loop completed
if cancel.is_set():
log.info("Stream cancelled by user")

View File

@@ -54,6 +54,7 @@ class ChatMessage(BaseModel):
refusal: str | None = None
tool_calls: list[dict] | None = None
function_call: dict | None = None
duration_ms: int | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
@@ -66,6 +67,7 @@ class ChatMessage(BaseModel):
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
duration_ms=prisma_message.durationMs,
)

View File

@@ -205,9 +205,10 @@ Important files (code, configs, outputs) should be saved to workspace to ensure
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `read_file` or `Read` (NOT `read_workspace_file`).
`read_workspace_file` reads from cloud workspace storage, where SDK
tool-results are NOT stored.
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
{_SHARED_TOOL_NOTES}{extra_notes}"""

View File

@@ -36,6 +36,10 @@ class CoPilotUsageStatus(BaseModel):
daily: UsageWindow
weekly: UsageWindow
reset_cost: int = Field(
default=0,
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
)
class RateLimitExceeded(Exception):
@@ -61,6 +65,7 @@ async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
rate_limit_reset_cost: int = 0,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
@@ -68,6 +73,7 @@ async def get_usage_status(
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
Returns:
CoPilotUsageStatus with current usage and limits.
@@ -97,6 +103,7 @@ async def get_usage_status(
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
reset_cost=rate_limit_reset_cost,
)
@@ -141,6 +148,110 @@ async def check_rate_limit(
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
"""Reset a user's daily token usage counter in Redis.
Called after a user pays credits to extend their daily limit.
Also reduces the weekly usage counter by ``daily_token_limit`` tokens
(clamped to 0) so the user effectively gets one extra day's worth of
weekly capacity.
Args:
user_id: The user's ID.
daily_token_limit: The configured daily token limit. When positive,
the weekly counter is reduced by this amount.
Fails open: returns False if Redis is unavailable (consistent with
the fail-open design of this module).
"""
now = datetime.now(UTC)
try:
redis = await get_redis_async()
# Use a MULTI/EXEC transaction so that DELETE (daily) and DECRBY
# (weekly) either both execute or neither does. This prevents the
# scenario where the daily counter is cleared but the weekly
# counter is not decremented — which would let the caller refund
# credits even though the daily limit was already reset.
d_key = _daily_key(user_id, now=now)
w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None
pipe = redis.pipeline(transaction=True)
pipe.delete(d_key)
if w_key is not None:
pipe.decrby(w_key, daily_token_limit)
results = await pipe.execute()
# Clamp negative weekly counter to 0 (best-effort; not critical).
if w_key is not None:
new_val = results[1] # DECRBY result
if new_val < 0:
await redis.set(w_key, 0, keepttl=True)
logger.info("Reset daily usage for user %s", user_id[:8])
return True
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for resetting daily usage")
return False
_RESET_LOCK_PREFIX = "copilot:reset_lock"
_RESET_COUNT_PREFIX = "copilot:reset_count"
async def acquire_reset_lock(user_id: str, ttl_seconds: int = 10) -> bool:
"""Acquire a short-lived lock to serialize rate limit resets per user."""
try:
redis = await get_redis_async()
key = f"{_RESET_LOCK_PREFIX}:{user_id}"
return bool(await redis.set(key, "1", nx=True, ex=ttl_seconds))
except (RedisError, ConnectionError, OSError) as exc:
logger.warning("Redis unavailable for reset lock, rejecting reset: %s", exc)
return False
async def release_reset_lock(user_id: str) -> None:
"""Release the per-user reset lock."""
try:
redis = await get_redis_async()
await redis.delete(f"{_RESET_LOCK_PREFIX}:{user_id}")
except (RedisError, ConnectionError, OSError):
pass # Lock will expire via TTL
async def get_daily_reset_count(user_id: str) -> int | None:
"""Get how many times the user has reset today.
Returns None when Redis is unavailable so callers can fail-closed
for billed operations (as opposed to failing open for read-only
rate-limit checks).
"""
now = datetime.now(UTC)
try:
redis = await get_redis_async()
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
val = await redis.get(key)
return int(val or 0)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for reading daily reset count")
return None
async def increment_daily_reset_count(user_id: str) -> None:
"""Increment and track how many resets this user has done today."""
now = datetime.now(UTC)
try:
redis = await get_redis_async()
key = f"{_RESET_COUNT_PREFIX}:{user_id}:{now.strftime('%Y-%m-%d')}"
pipe = redis.pipeline(transaction=True)
pipe.incr(key)
seconds_until_reset = int((_daily_reset_time(now=now) - now).total_seconds())
pipe.expire(key, max(seconds_until_reset, 1))
await pipe.execute()
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for tracking reset count")
async def record_token_usage(
user_id: str,
prompt_tokens: int,
@@ -231,6 +342,67 @@ async def record_token_usage(
)
async def get_global_rate_limits(
user_id: str,
config_daily: int,
config_weekly: int,
) -> tuple[int, int]:
"""Resolve global rate limits from LaunchDarkly, falling back to config.
Args:
user_id: User ID for LD flag evaluation context.
config_daily: Fallback daily limit from ChatConfig.
config_weekly: Fallback weekly limit from ChatConfig.
Returns:
(daily_token_limit, weekly_token_limit) tuple.
"""
# Lazy import to avoid circular dependency:
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
from backend.util.feature_flag import Flag, get_feature_flag_value
daily_raw = await get_feature_flag_value(
Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily
)
weekly_raw = await get_feature_flag_value(
Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly
)
try:
daily = max(0, int(daily_raw))
except (TypeError, ValueError):
logger.warning("Invalid LD value for daily token limit: %r", daily_raw)
daily = config_daily
try:
weekly = max(0, int(weekly_raw))
except (TypeError, ValueError):
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
weekly = config_weekly
return daily, weekly
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
"""Reset a user's usage counters.
Always deletes the daily Redis key. When *reset_weekly* is ``True``,
the weekly key is deleted as well.
Unlike read paths (``get_usage_status``, ``check_rate_limit``) which
fail-open on Redis errors, resets intentionally re-raise so the caller
knows the operation did not succeed. A silent failure here would leave
the admin believing the counters were zeroed when they were not.
"""
now = datetime.now(UTC)
keys_to_delete = [_daily_key(user_id, now=now)]
if reset_weekly:
keys_to_delete.append(_weekly_key(user_id, now=now))
try:
redis = await get_redis_async()
await redis.delete(*keys_to_delete)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for resetting user usage")
raise
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

View File

@@ -12,6 +12,7 @@ from .rate_limit import (
check_rate_limit,
get_usage_status,
record_token_usage,
reset_daily_usage,
)
_USER = "test-user-rl"
@@ -332,3 +333,91 @@ class TestRecordTokenUsage:
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# ---------------------------------------------------------------------------
# reset_daily_usage
# ---------------------------------------------------------------------------
class TestResetDailyUsage:
@staticmethod
def _make_pipeline_mock(decrby_result: int = 0) -> MagicMock:
"""Create a pipeline mock that returns [delete_result, decrby_result]."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[1, decrby_result])
return pipe
@pytest.mark.asyncio
async def test_deletes_daily_key(self):
mock_pipe = self._make_pipeline_mock(decrby_result=0)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
result = await reset_daily_usage(_USER, daily_token_limit=10000)
assert result is True
mock_pipe.delete.assert_called_once()
@pytest.mark.asyncio
async def test_reduces_weekly_usage_via_decrby(self):
"""Weekly counter should be reduced via DECRBY in the pipeline."""
mock_pipe = self._make_pipeline_mock(decrby_result=35000)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed
@pytest.mark.asyncio
async def test_clamps_negative_weekly_to_zero(self):
"""If DECRBY goes negative, SET to 0 (outside the pipeline)."""
mock_pipe = self._make_pipeline_mock(decrby_result=-5000)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_called_once()
@pytest.mark.asyncio
async def test_no_weekly_reduction_when_daily_limit_zero(self):
"""When daily_token_limit is 0, weekly counter should not be touched."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=0)
mock_pipe.delete.assert_called_once()
mock_pipe.decrby.assert_not_called()
@pytest.mark.asyncio
async def test_returns_false_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
result = await reset_daily_usage(_USER, daily_token_limit=10000)
assert result is False

View File

@@ -0,0 +1,294 @@
"""Unit tests for the POST /usage/reset endpoint."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from backend.api.features.chat.routes import reset_copilot_usage
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.util.exceptions import InsufficientBalanceError
# Minimal config mock matching ChatConfig fields used by the endpoint.
def _make_config(
rate_limit_reset_cost: int = 500,
daily_token_limit: int = 2_500_000,
weekly_token_limit: int = 12_500_000,
max_daily_resets: int = 5,
):
cfg = MagicMock()
cfg.rate_limit_reset_cost = rate_limit_reset_cost
cfg.daily_token_limit = daily_token_limit
cfg.weekly_token_limit = weekly_token_limit
cfg.max_daily_resets = max_daily_resets
return cfg
def _usage(daily_used: int = 3_000_000, daily_limit: int = 2_500_000):
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_limit,
resets_at=datetime.now(UTC) + timedelta(hours=6),
),
weekly=UsageWindow(
used=5_000_000,
limit=12_500_000,
resets_at=datetime.now(UTC) + timedelta(days=3),
),
)
_MODULE = "backend.api.features.chat.routes"
def _mock_settings(enable_credit: bool = True):
"""Return a mock Settings object with the given enable_credit flag."""
mock = MagicMock()
mock.config.enable_credit = enable_credit
return mock
@pytest.mark.asyncio
class TestResetCopilotUsage:
async def test_feature_disabled_returns_400(self):
"""When rate_limit_reset_cost=0, endpoint returns 400."""
with patch(f"{_MODULE}.config", _make_config(rate_limit_reset_cost=0)):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "not available" in exc_info.value.detail
async def test_no_daily_limit_returns_400(self):
"""When daily_token_limit=0 (unlimited), endpoint returns 400."""
with (
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
patch(f"{_MODULE}.settings", _mock_settings()),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "nothing to reset" in exc_info.value.detail.lower()
async def test_not_at_limit_returns_400(self):
"""When user hasn't hit their daily limit, returns 400."""
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage(daily_used=1_000_000)),
),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "not reached" in exc_info.value.detail
mock_release.assert_awaited_once()
async def test_insufficient_credits_returns_402(self):
"""When user doesn't have enough credits, returns 402."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.side_effect = InsufficientBalanceError(
message="Insufficient balance",
user_id="user-1",
balance=50,
amount=200,
)
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage()),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 402
mock_release.assert_awaited_once()
async def test_happy_path(self):
"""Successful reset: charges credits, resets usage, returns response."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.return_value = 1500 # remaining balance
cfg = _make_config()
updated_usage = _usage(daily_used=0)
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(side_effect=[_usage(), updated_usage]),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
patch(
f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=True)
) as mock_reset,
patch(f"{_MODULE}.increment_daily_reset_count", AsyncMock()) as mock_incr,
):
result = await reset_copilot_usage(user_id="user-1")
assert result.success is True
assert result.credits_charged == 500
assert result.remaining_balance == 1500
mock_reset.assert_awaited_once()
mock_incr.assert_awaited_once()
async def test_max_daily_resets_exceeded(self):
"""When user has exhausted daily resets, returns 429."""
cfg = _make_config(max_daily_resets=3)
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 429
async def test_credit_system_disabled_returns_400(self):
"""When enable_credit=False, endpoint returns 400."""
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings(enable_credit=False)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "credit system is disabled" in exc_info.value.detail.lower()
async def test_weekly_limit_exhausted_returns_400(self):
"""When the weekly limit is also exhausted, resetting daily won't help."""
cfg = _make_config()
weekly_exhausted = CoPilotUsageStatus(
daily=UsageWindow(
used=3_000_000,
limit=2_500_000,
resets_at=datetime.now(UTC) + timedelta(hours=6),
),
weekly=UsageWindow(
used=12_500_000,
limit=12_500_000,
resets_at=datetime.now(UTC) + timedelta(days=3),
),
)
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=weekly_exhausted),
),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 400
assert "weekly" in exc_info.value.detail.lower()
mock_release.assert_awaited_once()
async def test_redis_failure_for_reset_count_returns_503(self):
"""When Redis is unavailable for get_daily_reset_count, returns 503."""
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 503
assert "verify" in exc_info.value.detail.lower()
async def test_redis_reset_failure_refunds_credits(self):
"""When reset_daily_usage fails, credits are refunded and 503 returned."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.return_value = 1500
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage()),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 503
assert "not been charged" in exc_info.value.detail
mock_credit_model.top_up_credits.assert_awaited_once()
async def test_redis_reset_failure_refund_also_fails(self):
"""When both reset and refund fail, error message reflects the truth."""
mock_credit_model = AsyncMock()
mock_credit_model.spend_credits.return_value = 1500
mock_credit_model.top_up_credits.side_effect = RuntimeError("db down")
cfg = _make_config()
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
patch(
f"{_MODULE}.get_usage_status",
AsyncMock(return_value=_usage()),
),
patch(
f"{_MODULE}.get_user_credit_model",
AsyncMock(return_value=mock_credit_model),
),
patch(f"{_MODULE}.reset_daily_usage", AsyncMock(return_value=False)),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
assert exc_info.value.status_code == 503
assert "contact support" in exc_info.value.detail.lower()

View File

@@ -67,9 +67,17 @@ These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Optional: `title`, `description`
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
- For dropdown/select inputs, use **AgentDropdownInputBlock** instead (see below)
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
- Optional: `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
@@ -208,6 +216,20 @@ call in a loop until the task is complete:
Regular blocks work exactly like sub-agents as tools — wire each input
field from `source_name: "tools"` on the Orchestrator side.
### Testing with Dry Run
After saving an agent, suggest a dry run to validate wiring without consuming
real API calls, credentials, or credits:
1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
sample inputs. This executes the graph with mock outputs, verifying that
links resolve correctly and required inputs are satisfied.
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
to inspect the full node-by-node execution trace. This shows what each node
received as input and produced as output, making it easy to spot wiring issues.
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
the agent JSON and re-save before suggesting a real execution.
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:

View File

@@ -7,20 +7,35 @@ without implementing their own event loop.
from __future__ import annotations
import logging
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
from backend.copilot.response_model import (
if TYPE_CHECKING:
from backend.copilot.permissions import CopilotPermissions
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from .. import stream_registry
from ..response_model import (
StreamError,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from .service import stream_chat_completion_sdk
if TYPE_CHECKING:
from backend.copilot.permissions import CopilotPermissions
logger = logging.getLogger(__name__)
# Identifiers used when registering AutoPilot-originated streams in the
# stream registry. Distinct from "chat_stream"/"chat" used by the HTTP SSE
# endpoint, making it easy to filter AutoPilot streams in logs/observability.
AUTOPILOT_TOOL_CALL_ID = "autopilot_stream"
AUTOPILOT_TOOL_NAME = "autopilot"
class CopilotResult:
@@ -46,6 +61,111 @@ class CopilotResult:
self.total_tokens: int = 0
class _RegistryHandle(BaseModel):
"""Tracks stream registry session state for cleanup."""
publish_turn_id: str = ""
error_msg: str | None = None
error_already_published: bool = False
@asynccontextmanager
async def _registry_session(
session_id: str, user_id: str, turn_id: str
) -> AsyncIterator[_RegistryHandle]:
"""Create a stream registry session and ensure it is finalized."""
handle = _RegistryHandle(publish_turn_id=turn_id)
try:
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id=AUTOPILOT_TOOL_CALL_ID,
tool_name=AUTOPILOT_TOOL_NAME,
turn_id=turn_id,
)
except (RedisError, ConnectionError, OSError):
logger.warning(
"[collect] Failed to create stream registry session for %s, "
"frontend will not receive real-time updates",
session_id[:12],
exc_info=True,
)
# Disable chunk publishing but keep finalization enabled so
# mark_session_completed can clean up any partial registry state.
handle.publish_turn_id = ""
try:
yield handle
finally:
try:
await stream_registry.mark_session_completed(
session_id,
error_message=handle.error_msg,
skip_error_publish=handle.error_already_published,
)
except (RedisError, ConnectionError, OSError):
logger.warning(
"[collect] Failed to mark stream completed for %s",
session_id[:12],
exc_info=True,
)
class _ToolCallEntry(BaseModel):
"""A single tool call observed during stream consumption."""
tool_call_id: str
tool_name: str
input: Any
output: Any = None
success: bool | None = None
class _EventAccumulator(BaseModel):
"""Mutable accumulator for stream events."""
response_parts: list[str] = Field(default_factory=list)
tool_calls: list[_ToolCallEntry] = Field(default_factory=list)
tool_calls_by_id: dict[str, _ToolCallEntry] = Field(default_factory=dict)
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
def _process_event(event: object, acc: _EventAccumulator) -> str | None:
"""Process a single stream event and return error_msg if StreamError.
Uses structural pattern matching for dispatch per project guidelines.
"""
match event:
case StreamTextDelta(delta=delta):
acc.response_parts.append(delta)
case StreamToolInputAvailable() as e:
entry = _ToolCallEntry(
tool_call_id=e.toolCallId,
tool_name=e.toolName,
input=e.input,
)
acc.tool_calls.append(entry)
acc.tool_calls_by_id[e.toolCallId] = entry
case StreamToolOutputAvailable() as e:
if tc := acc.tool_calls_by_id.get(e.toolCallId):
tc.output = e.output
tc.success = e.success
else:
logger.debug(
"Received tool output for unknown tool_call_id: %s",
e.toolCallId,
)
case StreamUsage() as e:
acc.prompt_tokens += e.prompt_tokens
acc.completion_tokens += e.completion_tokens
acc.total_tokens += e.total_tokens
case StreamError(errorText=err):
return err
return None
async def collect_copilot_response(
*,
session_id: str,
@@ -56,11 +176,8 @@ async def collect_copilot_response(
) -> CopilotResult:
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
This is the recommended entry-point for callers that need a simple
request-response interface (e.g. the AutoPilot block) rather than
streaming individual events. It avoids duplicating the event-collection
logic and does NOT wrap the stream in ``asyncio.timeout`` — the SDK
manages its own heartbeat-based timeouts internally.
Registers with the stream registry so the frontend can connect via SSE
and receive real-time updates while the AutoPilot block is executing.
Args:
session_id: Chat session to use.
@@ -77,39 +194,39 @@ async def collect_copilot_response(
Raises:
RuntimeError: If the stream yields a ``StreamError`` event.
"""
turn_id = str(uuid.uuid4())
async with _registry_session(session_id, user_id, turn_id) as handle:
try:
raw_stream = stream_chat_completion_sdk(
session_id=session_id,
message=message,
is_user_message=is_user_message,
user_id=user_id,
permissions=permissions,
)
published_stream = stream_registry.stream_and_publish(
session_id=session_id,
turn_id=handle.publish_turn_id,
stream=raw_stream,
)
acc = _EventAccumulator()
async for event in published_stream:
if err := _process_event(event, acc):
handle.error_msg = err
# stream_and_publish skips StreamError events, so
# mark_session_completed must publish the error to Redis.
handle.error_already_published = False
raise RuntimeError(f"Copilot error: {err}")
except Exception:
if handle.error_msg is None:
handle.error_msg = "AutoPilot execution failed"
raise
result = CopilotResult()
response_parts: list[str] = []
tool_calls_by_id: dict[str, dict[str, Any]] = {}
async for event in stream_chat_completion_sdk(
session_id=session_id,
message=message,
is_user_message=is_user_message,
user_id=user_id,
permissions=permissions,
):
if isinstance(event, StreamTextDelta):
response_parts.append(event.delta)
elif isinstance(event, StreamToolInputAvailable):
entry: dict[str, Any] = {
"tool_call_id": event.toolCallId,
"tool_name": event.toolName,
"input": event.input,
"output": None,
"success": None,
}
result.tool_calls.append(entry)
tool_calls_by_id[event.toolCallId] = entry
elif isinstance(event, StreamToolOutputAvailable):
if tc := tool_calls_by_id.get(event.toolCallId):
tc["output"] = event.output
tc["success"] = event.success
elif isinstance(event, StreamUsage):
result.prompt_tokens += event.prompt_tokens
result.completion_tokens += event.completion_tokens
result.total_tokens += event.total_tokens
elif isinstance(event, StreamError):
raise RuntimeError(f"Copilot error: {event.errorText}")
result.response_text = "".join(response_parts)
result.response_text = "".join(acc.response_parts)
result.tool_calls = [tc.model_dump() for tc in acc.tool_calls]
result.prompt_tokens = acc.prompt_tokens
result.completion_tokens = acc.completion_tokens
result.total_tokens = acc.total_tokens
return result

View File

@@ -0,0 +1,177 @@
"""Tests for collect_copilot_response stream registry integration."""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.sdk.collect import collect_copilot_response
def _mock_stream_fn(*events):
"""Return a callable that returns an async generator."""
async def _gen(**_kwargs):
for e in events:
yield e
return _gen
@pytest.fixture
def mock_registry():
"""Patch stream_registry module used by collect."""
with patch("backend.copilot.sdk.collect.stream_registry") as m:
m.create_session = AsyncMock()
m.publish_chunk = AsyncMock()
m.mark_session_completed = AsyncMock()
# stream_and_publish: pass-through that also publishes (real logic)
# We re-implement the pass-through here so the event loop works,
# but still track publish_chunk calls via the mock.
async def _stream_and_publish(session_id, turn_id, stream):
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
await m.publish_chunk(turn_id, event)
yield event
m.stream_and_publish = _stream_and_publish
yield m
@pytest.fixture
def stream_fn_patch():
"""Helper to patch stream_chat_completion_sdk."""
def _patch(events):
return patch(
"backend.copilot.sdk.collect.stream_chat_completion_sdk",
new=_mock_stream_fn(*events),
)
return _patch
@pytest.mark.asyncio
async def test_stream_registry_called_on_success(mock_registry, stream_fn_patch):
"""Stream registry create/publish/complete are called correctly on success."""
events = [
StreamTextDelta(id="t1", delta="Hello "),
StreamTextDelta(id="t1", delta="world"),
StreamUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
StreamFinish(),
]
with stream_fn_patch(events):
result = await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
assert result.response_text == "Hello world"
assert result.total_tokens == 15
mock_registry.create_session.assert_awaited_once()
# StreamFinish should NOT be published (mark_session_completed does it)
published_types = [
type(call.args[1]).__name__
for call in mock_registry.publish_chunk.call_args_list
]
assert "StreamFinish" not in published_types
assert "StreamTextDelta" in published_types
mock_registry.mark_session_completed.assert_awaited_once()
_, kwargs = mock_registry.mark_session_completed.call_args
assert kwargs.get("error_message") is None
@pytest.mark.asyncio
async def test_stream_registry_error_on_stream_error(mock_registry, stream_fn_patch):
"""mark_session_completed receives error message when StreamError occurs."""
events = [
StreamTextDelta(id="t1", delta="partial"),
StreamError(errorText="something broke"),
]
with stream_fn_patch(events):
with pytest.raises(RuntimeError, match="something broke"):
await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
_, kwargs = mock_registry.mark_session_completed.call_args
assert kwargs.get("error_message") == "something broke"
# stream_and_publish skips StreamError, so mark_session_completed must
# publish it (skip_error_publish=False).
assert kwargs.get("skip_error_publish") is False
# StreamError should NOT be published via publish_chunk — mark_session_completed
# handles it to avoid double-publication.
published_types = [
type(call.args[1]).__name__
for call in mock_registry.publish_chunk.call_args_list
]
assert "StreamError" not in published_types
@pytest.mark.asyncio
async def test_graceful_degradation_when_create_session_fails(
mock_registry, stream_fn_patch
):
"""AutoPilot still works when stream registry create_session raises."""
events = [
StreamTextDelta(id="t1", delta="works"),
StreamFinish(),
]
mock_registry.create_session = AsyncMock(side_effect=ConnectionError("Redis down"))
with stream_fn_patch(events):
result = await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
assert result.response_text == "works"
# publish_chunk should NOT be called because turn_id was cleared
mock_registry.publish_chunk.assert_not_awaited()
# mark_session_completed IS still called to clean up any partial state
mock_registry.mark_session_completed.assert_awaited_once()
@pytest.mark.asyncio
async def test_tool_calls_published_and_collected(mock_registry, stream_fn_patch):
"""Tool call events are both published to registry and collected in result."""
events = [
StreamToolInputAvailable(
toolCallId="tc-1", toolName="read_file", input={"path": "/tmp"}
),
StreamToolOutputAvailable(
toolCallId="tc-1", output="file contents", success=True
),
StreamTextDelta(id="t1", delta="done"),
StreamFinish(),
]
with stream_fn_patch(events):
result = await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0]["tool_name"] == "read_file"
assert result.tool_calls[0]["output"] == "file contents"
assert result.tool_calls[0]["success"] is True
assert result.response_text == "done"

View File

@@ -25,24 +25,64 @@ def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
Use this helper in any copilot SDK test that needs a well-formed
transcript without hitting the real storage layer.
Delegates to ``build_structured_transcript`` — plain content strings
are automatically wrapped in ``[{"type": "text", "text": ...}]`` for
assistant messages.
"""
# Cast widening: tuple[str, str] is structurally compatible with
# tuple[str, str | list[dict]] but list invariance requires explicit
# annotation.
widened: list[tuple[str, str | list[dict]]] = list(pairs)
return build_structured_transcript(widened)
def build_structured_transcript(
entries: list[tuple[str, str | list[dict]]],
) -> str:
"""Build a JSONL transcript with structured content blocks.
Each entry is (role, content) where content is either a plain string
(for user messages) or a list of content block dicts (for assistant
messages with thinking/tool_use/text blocks).
Example::
build_structured_transcript([
("user", "Hello"),
("assistant", [
{"type": "thinking", "thinking": "...", "signature": "sig1"},
{"type": "text", "text": "Hi there"},
]),
])
"""
lines: list[str] = []
last_uuid: str | None = None
for role, content in pairs:
for role, content in entries:
uid = str(uuid4())
entry_type = "assistant" if role == "assistant" else "user"
msg: dict = {"role": role, "content": content}
if role == "assistant":
msg.update(
{
"model": "",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
)
if role == "assistant" and isinstance(content, list):
msg: dict = {
"role": "assistant",
"model": "claude-test",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": content,
"stop_reason": "end_turn",
"stop_sequence": None,
}
elif role == "assistant":
msg = {
"role": "assistant",
"model": "claude-test",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
else:
msg = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,

View File

@@ -2,7 +2,7 @@
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
Glob/Grep so that all file operations share the same ``/home/user``
filesystem as ``bash_exec``.
and ``/tmp`` filesystems as ``bash_exec``.
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
@@ -16,10 +16,13 @@ import shlex
from typing import Any, Callable
from backend.copilot.context import (
E2B_ALLOWED_DIRS,
E2B_ALLOWED_DIRS_STR,
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
is_within_allowed_dirs,
resolve_sandbox_path,
)
@@ -36,7 +39,7 @@ async def _check_sandbox_symlink_escape(
``readlink -f`` follows actual symlinks on the sandbox filesystem.
Returns the canonical parent path, or ``None`` if the path escapes
``E2B_WORKDIR``.
the allowed sandbox directories.
Note: There is an inherent TOCTOU window between this check and the
subsequent ``sandbox.files.write()``. A symlink could theoretically be
@@ -52,10 +55,7 @@ async def _check_sandbox_symlink_escape(
if (
canonical_res.exit_code != 0
or not canonical_parent
or (
canonical_parent != E2B_WORKDIR
and not canonical_parent.startswith(E2B_WORKDIR + "/")
)
or not is_within_allowed_dirs(canonical_parent)
):
return None
return canonical_parent
@@ -89,6 +89,38 @@ def _get_sandbox_and_path(
return sandbox, remote
async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
"""Write *content* to *path* inside the sandbox.
The E2B filesystem API (``sandbox.files.write``) and the command API
(``sandbox.commands.run``) run as **different users**. On ``/tmp``
(which has the sticky bit set) this means ``sandbox.files.write`` can
create new files but cannot overwrite files previously created by
``sandbox.commands.run`` (or itself), because the sticky bit restricts
deletion/rename to the file owner.
To work around this, writes targeting ``/tmp`` are performed via
``tee`` through the command API, which runs as the sandbox ``user``
and can therefore always overwrite user-owned files.
"""
if path == "/tmp" or path.startswith("/tmp/"):
import base64 as _b64
encoded = _b64.b64encode(content.encode()).decode()
result = await sandbox.commands.run(
f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
cwd=E2B_WORKDIR,
timeout=10,
)
if result.exit_code != 0:
raise RuntimeError(
f"shell write failed (exit {result.exit_code}): "
+ (result.stderr or "").strip()
)
else:
await sandbox.files.write(path, content)
# Tool handlers
@@ -139,13 +171,16 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
try:
parent = os.path.dirname(remote)
if parent and parent != E2B_WORKDIR:
if parent and parent not in E2B_ALLOWED_DIRS:
await sandbox.files.make_dir(parent)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
return _mcp(
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
error=True,
)
remote = os.path.join(canonical_parent, os.path.basename(remote))
await sandbox.files.write(remote, content)
await _sandbox_write(sandbox, remote, content)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
@@ -172,7 +207,10 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
parent = os.path.dirname(remote)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
return _mcp(
f"Path must be within {E2B_ALLOWED_DIRS_STR}: {os.path.basename(parent)}",
error=True,
)
remote = os.path.join(canonical_parent, os.path.basename(remote))
try:
@@ -197,7 +235,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
else content.replace(old_string, new_string, 1)
)
try:
await sandbox.files.write(remote, updated)
await _sandbox_write(sandbox, remote, updated)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
@@ -290,14 +328,14 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
(
"read_file",
"Read a file from the cloud sandbox (/home/user). "
"Read a file from the cloud sandbox (/home/user or /tmp). "
"Use offset and limit for large files.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
},
"offset": {
"type": "integer",
@@ -314,7 +352,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
),
(
"write_file",
"Write or create a file in the cloud sandbox (/home/user). "
"Write or create a file in the cloud sandbox (/home/user or /tmp). "
"Parent directories are created automatically. "
"To copy a workspace file into the sandbox, use "
"read_workspace_file with save_to_path instead.",
@@ -323,7 +361,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
},
"content": {"type": "string", "description": "Content to write."},
},
@@ -340,7 +378,7 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
"description": "Path (relative to /home/user, or absolute under /home/user or /tmp).",
},
"old_string": {"type": "string", "description": "Text to find."},
"new_string": {"type": "string", "description": "Replacement text."},

View File

@@ -15,6 +15,7 @@ from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_proj
from .e2b_file_tools import (
_check_sandbox_symlink_escape,
_read_local,
_sandbox_write,
resolve_sandbox_path,
)
@@ -39,23 +40,23 @@ class TestResolveSandboxPath:
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
@@ -68,6 +69,24 @@ class TestResolveSandboxPath:
"""Path that resolves back within E2B_WORKDIR is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
def test_tmp_absolute_allowed(self):
assert resolve_sandbox_path("/tmp/data.txt") == "/tmp/data.txt"
def test_tmp_nested_allowed(self):
assert resolve_sandbox_path("/tmp/a/b/c.txt") == "/tmp/a/b/c.txt"
def test_tmp_itself_allowed(self):
assert resolve_sandbox_path("/tmp") == "/tmp"
def test_tmp_escape_blocked(self):
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path("/tmp/../etc/passwd")
def test_tmp_prefix_collision_blocked(self):
"""A path like /tmp_evil should be blocked (not a prefix match)."""
with pytest.raises(ValueError, match="must be within"):
resolve_sandbox_path("/tmp_evil/malicious.txt")
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
@@ -227,3 +246,92 @@ class TestCheckSandboxSymlinkEscape:
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
assert result == f"{E2B_WORKDIR}/a/b/c/d"
@pytest.mark.asyncio
async def test_tmp_path_allowed(self):
"""Paths resolving to /tmp are allowed."""
sandbox = _make_sandbox(stdout="/tmp/workdir\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, "/tmp/workdir")
assert result == "/tmp/workdir"
@pytest.mark.asyncio
async def test_tmp_itself_allowed(self):
"""The /tmp directory itself is allowed."""
sandbox = _make_sandbox(stdout="/tmp\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, "/tmp")
assert result == "/tmp"
# ---------------------------------------------------------------------------
# _sandbox_write — routing writes through shell for /tmp paths
# ---------------------------------------------------------------------------
class TestSandboxWrite:
@pytest.mark.asyncio
async def test_tmp_path_uses_shell_command(self):
"""Writes to /tmp should use commands.run (shell) instead of files.write."""
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
files = SimpleNamespace(write=AsyncMock())
sandbox = SimpleNamespace(commands=commands, files=files)
await _sandbox_write(sandbox, "/tmp/test.py", "print('hello')")
commands.run.assert_called_once()
files.write.assert_not_called()
@pytest.mark.asyncio
async def test_home_user_path_uses_files_api(self):
"""Writes to /home/user should use sandbox.files.write."""
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
files = SimpleNamespace(write=AsyncMock())
sandbox = SimpleNamespace(commands=commands, files=files)
await _sandbox_write(sandbox, "/home/user/test.py", "print('hello')")
files.write.assert_called_once_with("/home/user/test.py", "print('hello')")
commands.run.assert_not_called()
@pytest.mark.asyncio
async def test_tmp_nested_path_uses_shell_command(self):
"""Writes to nested /tmp paths should use commands.run."""
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
files = SimpleNamespace(write=AsyncMock())
sandbox = SimpleNamespace(commands=commands, files=files)
await _sandbox_write(sandbox, "/tmp/subdir/file.txt", "content")
commands.run.assert_called_once()
files.write.assert_not_called()
@pytest.mark.asyncio
async def test_tmp_write_shell_failure_raises(self):
"""Shell write failure should raise RuntimeError."""
run_result = SimpleNamespace(stdout="", stderr="No space left", exit_code=1)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
sandbox = SimpleNamespace(commands=commands)
with pytest.raises(RuntimeError, match="shell write failed"):
await _sandbox_write(sandbox, "/tmp/test.txt", "content")
@pytest.mark.asyncio
async def test_tmp_write_preserves_content_with_special_chars(self):
"""Content with special shell characters should be preserved via base64."""
import base64
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
sandbox = SimpleNamespace(commands=commands)
content = "print(\"Hello $USER\")\n# a `backtick` and 'quotes'\n"
await _sandbox_write(sandbox, "/tmp/special.py", content)
# Verify the command contains base64-encoded content
call_args = commands.run.call_args[0][0]
# Extract the base64 string from the command
encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
decoded = base64.b64decode(encoded_in_cmd).decode()
assert decoded == content

View File

@@ -0,0 +1,68 @@
"""SDK environment variable builder — importable without circular deps.
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
can reuse the same subscription / OpenRouter / direct-Anthropic logic
without pulling in the full copilot service module (which would create a
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
"""
from __future__ import annotations
from backend.copilot.config import ChatConfig
from backend.copilot.sdk.subscription import validate_subscription
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
# A singleton would require importing service.py which causes the circular dep
# this module was created to avoid.
config = ChatConfig()
def build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
``ANTHROPIC_API_KEY`` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
validate_subscription()
return {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# --- Mode 2: Direct Anthropic (no proxy hop) ---
if not config.openrouter_active:
return {}
# --- Mode 3: OpenRouter proxy ---
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
return env

View File

@@ -0,0 +1,242 @@
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
from unittest.mock import patch
import pytest
from backend.copilot.config import ChatConfig
# ---------------------------------------------------------------------------
# Helpers — build a ChatConfig with explicit field values so tests don't
# depend on real environment variables.
# ---------------------------------------------------------------------------
def _make_config(**overrides) -> ChatConfig:
"""Create a ChatConfig with safe defaults, applying *overrides*."""
defaults = {
"use_claude_code_subscription": False,
"use_openrouter": False,
"api_key": None,
"base_url": None,
}
defaults.update(overrides)
return ChatConfig(**defaults)
# ---------------------------------------------------------------------------
# Mode 1 — Subscription auth
# ---------------------------------------------------------------------------
class TestBuildSdkEnvSubscription:
"""When ``use_claude_code_subscription`` is True, keys are blanked."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_returns_blanked_keys(self, mock_validate):
"""Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
mock_validate.assert_called_once()
@patch(
"backend.copilot.sdk.env.validate_subscription",
side_effect=RuntimeError("CLI not found"),
)
def test_propagates_validation_error(self, mock_validate):
"""If validate_subscription fails, the error bubbles up."""
cfg = _make_config(use_claude_code_subscription=True)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
with pytest.raises(RuntimeError, match="CLI not found"):
build_sdk_env()
# ---------------------------------------------------------------------------
# Mode 2 — Direct Anthropic (no OpenRouter)
# ---------------------------------------------------------------------------
class TestBuildSdkEnvDirectAnthropic:
"""When OpenRouter is inactive, return empty dict (inherit parent env)."""
def test_returns_empty_dict_when_openrouter_inactive(self):
cfg = _make_config(use_openrouter=False)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
"""OpenRouter flag is True but no api_key => openrouter_active is False."""
cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
# Force api_key to None after construction (field_validator may pick up env vars)
object.__setattr__(cfg, "api_key", None)
assert not cfg.openrouter_active
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result == {}
# ---------------------------------------------------------------------------
# Mode 3 — OpenRouter proxy
# ---------------------------------------------------------------------------
class TestBuildSdkEnvOpenRouter:
"""When OpenRouter is active, return proxy env vars."""
def _openrouter_config(self, **overrides):
defaults = {
"use_openrouter": True,
"api_key": "sk-or-test-key",
"base_url": "https://openrouter.ai/api/v1",
}
defaults.update(overrides)
return _make_config(**defaults)
def test_basic_openrouter_env(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
assert result["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
assert result["ANTHROPIC_API_KEY"] == ""
assert "ANTHROPIC_CUSTOM_HEADERS" not in result
def test_strips_trailing_v1(self):
"""The /v1 suffix is stripped from the base URL."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_strips_trailing_v1_and_slash(self):
"""Trailing slash before /v1 strip is handled."""
cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# rstrip("/") first, then remove /v1
assert result["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
def test_no_v1_suffix_left_alone(self):
"""A base URL without /v1 is used as-is."""
cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
assert result["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"
def test_session_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="sess-123")
assert "ANTHROPIC_CUSTOM_HEADERS" in result
assert "x-session-id: sess-123" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_user_id_header(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(user_id="user-456")
assert "x-user-id: user-456" in result["ANTHROPIC_CUSTOM_HEADERS"]
def test_both_headers(self):
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="s1", user_id="u2")
headers = result["ANTHROPIC_CUSTOM_HEADERS"]
assert "x-session-id: s1" in headers
assert "x-user-id: u2" in headers
# They should be newline-separated
assert "\n" in headers
def test_header_sanitisation_strips_newlines(self):
"""Newlines/carriage-returns in header values are stripped."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env(session_id="bad\r\nvalue")
header_val = result["ANTHROPIC_CUSTOM_HEADERS"]
# The _safe helper removes \r and \n
assert "\r" not in header_val.split(": ", 1)[1]
assert "badvalue" in header_val
def test_header_value_truncated_to_128_chars(self):
"""Header values are truncated to 128 characters."""
cfg = self._openrouter_config()
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
long_id = "x" * 200
result = build_sdk_env(session_id=long_id)
# The value after "x-session-id: " should be at most 128 chars
header_line = result["ANTHROPIC_CUSTOM_HEADERS"]
value = header_line.split(": ", 1)[1]
assert len(value) == 128
# ---------------------------------------------------------------------------
# Mode priority
# ---------------------------------------------------------------------------
class TestBuildSdkEnvModePriority:
"""Subscription mode takes precedence over OpenRouter."""
@patch("backend.copilot.sdk.env.validate_subscription")
def test_subscription_overrides_openrouter(self, mock_validate):
cfg = _make_config(
use_claude_code_subscription=True,
use_openrouter=True,
api_key="sk-or-key",
base_url="https://openrouter.ai/api/v1",
)
with patch("backend.copilot.sdk.env.config", cfg):
from backend.copilot.sdk.env import build_sdk_env
result = build_sdk_env()
# Should get subscription result, not OpenRouter
assert result == {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}

View File

@@ -442,8 +442,11 @@ class TestCompactTranscript:
assert result is not None
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
assert msgs[1]["content"] == "Summarized response"
# The last assistant entry is preserved verbatim from original
assert msgs[2]["content"] == "Details"
@pytest.mark.asyncio
async def test_returns_none_on_compression_failure(self, mock_chat_config):

View File

@@ -15,6 +15,7 @@ from claude_agent_sdk import (
ResultMessage,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
@@ -100,6 +101,11 @@ class SDKResponseAdapter:
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ThinkingBlock):
# Thinking blocks are preserved in the transcript but
# not streamed to the frontend — skip silently.
pass
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)

View File

@@ -124,8 +124,11 @@ class TestScenarioCompactAndRetry:
assert result != original # Must be different
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
assert msgs[0]["content"] == "[summary of conversation]"
# Last assistant preserved verbatim
assert msgs[2]["content"] == "Long answer 2"
def test_compacted_transcript_loads_into_builder(self):
"""TranscriptBuilder can load a compacted transcript and continue."""
@@ -737,7 +740,10 @@ class TestRetryEdgeCases:
assert result is not None
assert result != transcript
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
# 3 messages: compressed prefix (2) + preserved last assistant (1)
assert len(msgs) == 3
# Last assistant preserved verbatim
assert msgs[2]["content"] == "Answer 19"
def test_messages_to_transcript_roundtrip_preserves_content(self):
"""Verify messages → transcript → messages preserves all content."""
@@ -1004,7 +1010,7 @@ def _make_sdk_patches(
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
(f"{_SVC}.set_execution_context", {}),
(

View File

@@ -77,9 +77,9 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
cancel_pending_tool_tasks,
create_copilot_mcp_server,
@@ -185,6 +185,24 @@ def _is_prompt_too_long(err: BaseException) -> bool:
return False
def _is_sdk_disconnect_error(exc: BaseException) -> bool:
"""Return True if *exc* is an expected SDK cleanup error from client disconnect.
Two known patterns occur when ``GeneratorExit`` tears down the async
generator and the SDK's ``__aexit__`` runs in a different context/task:
* ``RuntimeError``: cancel scope exited in wrong task (anyio)
* ``ValueError``: ContextVar token created in a different Context (OTEL)
These are suppressed to avoid polluting Sentry with non-actionable noise.
"""
if isinstance(exc, RuntimeError) and "cancel scope" in str(exc):
return True
if isinstance(exc, ValueError) and "was created in a different Context" in str(exc):
return True
return False
def _is_tool_only_message(sdk_msg: object) -> bool:
"""Return True if *sdk_msg* is an AssistantMessage containing only ToolUseBlocks.
@@ -409,6 +427,63 @@ _HEARTBEAT_INTERVAL = 10.0 # seconds
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
async def _safe_close_sdk_client(
sdk_client: ClaudeSDKClient,
log_prefix: str,
) -> None:
"""Close a ClaudeSDKClient, suppressing errors from client disconnect.
When the SSE client disconnects mid-stream, ``GeneratorExit`` propagates
through the async generator stack and causes ``ClaudeSDKClient.__aexit__``
to run in a different async context or task than where the client was
opened. This triggers two known error classes:
* ``ValueError``: ``<Token var=<ContextVar name='current_context'>>
was created in a different Context`` — OpenTelemetry's
``context.detach()`` fails because the OTEL context token was
created in the original generator coroutine but detach runs in
the GC / cleanup coroutine (Sentry: AUTOGPT-SERVER-8BT).
* ``RuntimeError``: ``Attempted to exit cancel scope in a different
task than it was entered in`` — anyio's ``TaskGroup.__aexit__``
detects that the cancel scope was entered in one task but is
being exited in another (Sentry: AUTOGPT-SERVER-8BW).
Both are harmless — the TCP connection is already dead and no
resources leak. Logging them at ``debug`` level keeps observability
without polluting Sentry.
"""
try:
await sdk_client.__aexit__(None, None, None)
except (ValueError, RuntimeError) as exc:
if _is_sdk_disconnect_error(exc):
# Expected during client disconnect — suppress to avoid Sentry noise.
logger.debug(
"%s SDK client cleanup error suppressed (client disconnect): %s: %s",
log_prefix,
type(exc).__name__,
exc,
)
else:
raise
except GeneratorExit:
# GeneratorExit can propagate through __aexit__ — suppress it here
# since the generator is already being torn down.
logger.debug(
"%s SDK client cleanup GeneratorExit suppressed (client disconnect)",
log_prefix,
)
except Exception:
# Unexpected cleanup error — log at error level so Sentry captures it
# (via its logging integration), but don't propagate since we're in
# teardown and the caller cannot meaningfully handle this.
logger.error(
"%s Unexpected SDK client cleanup error",
log_prefix,
exc_info=True,
)
async def _iter_sdk_messages(
client: ClaudeSDKClient,
) -> AsyncGenerator[Any, None]:
@@ -492,60 +567,6 @@ def _resolve_sdk_model() -> str | None:
return model
def _build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
Three modes (checked in order):
1. **Subscription** — clears all keys; CLI uses `claude login` auth.
2. **Direct Anthropic** — returns `{}`; subprocess inherits
`ANTHROPIC_API_KEY` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
_validate_claude_code_subscription()
return {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
# --- Mode 2: Direct Anthropic (no proxy hop) ---
# `openrouter_active` checks the flag *and* credential presence.
if not config.openrouter_active:
return {}
# --- Mode 3: OpenRouter proxy ---
# Strip /v1 suffix — SDK expects the base URL without a version path.
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
}
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
def _safe(v: str) -> str:
"""Sanitise a header value: strip newlines/whitespace and cap length."""
return v.replace("\r", "").replace("\n", "").strip()[:128]
parts = []
if session_id:
parts.append(f"x-session-id: {_safe(session_id)}")
if user_id:
parts.append(f"x-user-id: {_safe(user_id)}")
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
@@ -595,7 +616,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"""Convert SDK content blocks to transcript format.
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
Unknown block types are logged and skipped.
Raw dicts (e.g. ``redacted_thinking`` blocks that the SDK may not have
a typed class for) are passed through verbatim to preserve them in the
transcript. Unknown typed block objects are logged and skipped.
"""
result: list[dict[str, Any]] = []
for block in blocks or []:
@@ -627,6 +650,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"signature": block.signature,
}
)
elif isinstance(block, dict) and "type" in block:
# Preserve raw dict blocks (e.g. redacted_thinking) verbatim.
result.append(block)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
@@ -1188,7 +1214,17 @@ async def _run_stream_attempt(
consecutive_empty_tool_calls = 0
async with ClaudeSDKClient(options=state.options) as client:
# Use manual __aenter__/__aexit__ instead of ``async with`` so we can
# suppress SDK cleanup errors that occur when the SSE client disconnects
# mid-stream. GeneratorExit causes the SDK's ``__aexit__`` to run in a
# different async context/task than where the client was opened, which
# triggers:
# - ValueError: ContextVar token mismatch (AUTOGPT-SERVER-8BT)
# - RuntimeError: cancel scope in wrong task (AUTOGPT-SERVER-8BW)
# Both are harmless — the TCP connection is already dead.
sdk_client = ClaudeSDKClient(options=state.options)
client = await sdk_client.__aenter__()
try:
logger.info(
"%s Sending query — resume=%s, total_msgs=%d, "
"query_len=%d, attached_files=%d, image_blocks=%d",
@@ -1448,6 +1484,8 @@ async def _run_stream_attempt(
if acc.stream_completed:
break
finally:
await _safe_close_sdk_client(sdk_client, ctx.log_prefix)
# --- Post-stream processing (only on success) ---
if state.adapter.has_unresolved_tool_calls:
@@ -1775,7 +1813,7 @@ async def stream_chat_completion_sdk(
)
# Fail fast when no API credentials are available at all.
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
if not config.api_key and not config.use_claude_code_subscription:
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY, "
@@ -2151,6 +2189,16 @@ async def stream_chat_completion_sdk(
log_prefix,
len(session.messages),
)
except GeneratorExit:
# GeneratorExit is raised when the async generator is closed by the
# caller (e.g. client disconnect, page refresh). We MUST release the
# stream lock here because the ``finally`` block at the end of this
# function may not execute when GeneratorExit propagates through nested
# async generators. Without this, the lock stays held for its full TTL
# and the user sees "Another stream is already active" on every retry.
logger.warning("%s GeneratorExit — releasing stream lock", log_prefix)
await lock.release()
raise
except BaseException as e:
# Catch BaseException to handle both Exception and CancelledError
# (CancelledError inherits from BaseException in Python 3.8+)
@@ -2159,9 +2207,16 @@ async def stream_chat_completion_sdk(
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
# SDK cleanup RuntimeError is expected during cancellation, log as warning
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
# SDK cleanup errors are expected during client disconnect —
# log as warning rather than error to reduce Sentry noise.
# These are normally caught by _safe_close_sdk_client but
# can escape in edge cases (e.g. GeneratorExit timing).
if _is_sdk_disconnect_error(e):
logger.warning(
"%s SDK cleanup error (client disconnect): %s",
log_prefix,
error_msg,
)
else:
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
@@ -2183,10 +2238,11 @@ async def stream_chat_completion_sdk(
)
# Yield StreamError for immediate feedback (only for non-cancellation errors)
# Skip for CancelledError and RuntimeError cleanup issues (both are cancellations)
is_cancellation = isinstance(e, asyncio.CancelledError) or (
isinstance(e, RuntimeError) and "cancel scope" in str(e)
)
# Skip for CancelledError and SDK disconnect cleanup errors — these
# are not actionable by the user and the SSE connection is already dead.
is_cancellation = isinstance(
e, asyncio.CancelledError
) or _is_sdk_disconnect_error(e)
if not is_cancellation:
yield StreamError(errorText=display_msg, code=code)

View File

@@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .service import _prepare_file_attachments, _resolve_sdk_model
from .service import (
_is_sdk_disconnect_error,
_prepare_file_attachments,
_resolve_sdk_model,
_safe_close_sdk_client,
)
@dataclass
@@ -499,3 +504,111 @@ class TestResolveSdkModel:
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert _resolve_sdk_model() == "claude-opus-4-6"
# ---------------------------------------------------------------------------
# _is_sdk_disconnect_error — classify client disconnect cleanup errors
# ---------------------------------------------------------------------------
class TestIsSdkDisconnectError:
"""Tests for _is_sdk_disconnect_error — identifies expected SDK cleanup errors."""
def test_cancel_scope_runtime_error(self):
"""RuntimeError about cancel scope in wrong task is a disconnect error."""
exc = RuntimeError(
"Attempted to exit cancel scope in a different task than it was entered in"
)
assert _is_sdk_disconnect_error(exc) is True
def test_context_var_value_error(self):
"""ValueError about ContextVar token mismatch is a disconnect error."""
exc = ValueError(
"<Token var=<ContextVar name='current_context'>> "
"was created in a different Context"
)
assert _is_sdk_disconnect_error(exc) is True
def test_unrelated_runtime_error(self):
"""Unrelated RuntimeError should NOT be classified as disconnect error."""
exc = RuntimeError("something else went wrong")
assert _is_sdk_disconnect_error(exc) is False
def test_unrelated_value_error(self):
"""Unrelated ValueError should NOT be classified as disconnect error."""
exc = ValueError("invalid argument")
assert _is_sdk_disconnect_error(exc) is False
def test_other_exception_types(self):
"""Non-RuntimeError/ValueError should NOT be classified as disconnect error."""
assert _is_sdk_disconnect_error(TypeError("bad type")) is False
assert _is_sdk_disconnect_error(OSError("network down")) is False
assert _is_sdk_disconnect_error(asyncio.CancelledError()) is False
# ---------------------------------------------------------------------------
# _safe_close_sdk_client — suppress cleanup errors during disconnect
# ---------------------------------------------------------------------------
class TestSafeCloseSdkClient:
"""Tests for _safe_close_sdk_client — suppresses expected SDK cleanup errors."""
@pytest.mark.asyncio
async def test_clean_exit(self):
"""Normal __aexit__ (no error) should succeed silently."""
client = AsyncMock()
client.__aexit__ = AsyncMock(return_value=None)
await _safe_close_sdk_client(client, "[test]")
client.__aexit__.assert_awaited_once_with(None, None, None)
@pytest.mark.asyncio
async def test_cancel_scope_runtime_error_suppressed(self):
"""RuntimeError from cancel scope mismatch should be suppressed."""
client = AsyncMock()
client.__aexit__ = AsyncMock(
side_effect=RuntimeError(
"Attempted to exit cancel scope in a different task"
)
)
# Should NOT raise
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_context_var_value_error_suppressed(self):
"""ValueError from ContextVar token mismatch should be suppressed."""
client = AsyncMock()
client.__aexit__ = AsyncMock(
side_effect=ValueError(
"<Token var=<ContextVar name='current_context'>> "
"was created in a different Context"
)
)
# Should NOT raise
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unexpected_exception_suppressed_with_error_log(self):
"""Unexpected exceptions should be caught (not propagated) but logged at error."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=OSError("unexpected"))
# Should NOT raise — unexpected errors are also suppressed to
# avoid crashing the generator during teardown. Logged at error
# level so Sentry captures them via its logging integration.
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unrelated_runtime_error_propagates(self):
"""Non-cancel-scope RuntimeError should propagate (not suppressed)."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=RuntimeError("something unrelated"))
with pytest.raises(RuntimeError, match="something unrelated"):
await _safe_close_sdk_client(client, "[test]")
@pytest.mark.asyncio
async def test_unrelated_value_error_propagates(self):
"""Non-disconnect ValueError should propagate (not suppressed)."""
client = AsyncMock()
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
with pytest.raises(ValueError, match="invalid argument"):
await _safe_close_sdk_client(client, "[test]")

View File

@@ -0,0 +1,822 @@
"""Tests for thinking/redacted_thinking block preservation.
Validates the fix for the Anthropic API error:
"thinking or redacted_thinking blocks in the latest assistant message
cannot be modified. These blocks must remain as they were in the
original response."
The API requires that thinking blocks in the LAST assistant message are
preserved value-identical. Older assistant messages may have thinking blocks
stripped entirely. This test suite covers:
1. _flatten_assistant_content — strips thinking from older messages
2. compact_transcript — preserves last assistant's thinking blocks
3. response_adapter — handles ThinkingBlock without error
4. _format_sdk_content_blocks — preserves redacted_thinking blocks
"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from claude_agent_sdk import AssistantMessage, TextBlock, ThinkingBlock
from backend.copilot.response_model import (
StreamStartStep,
StreamTextDelta,
StreamTextStart,
)
from backend.util import json
from .conftest import build_structured_transcript
from .response_adapter import SDKResponseAdapter
from .service import _format_sdk_content_blocks
from .transcript import (
_find_last_assistant_entry,
_flatten_assistant_content,
_messages_to_transcript,
_rechain_tail,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
# ---------------------------------------------------------------------------
# Fixtures: realistic thinking block content
# ---------------------------------------------------------------------------
THINKING_BLOCK = {
"type": "thinking",
"thinking": "Let me analyze the user's request carefully...",
"signature": "ErUBCkYIAxgCIkD0V2MsRXPkuGolGexaW9V1kluijxXGF",
}
REDACTED_THINKING_BLOCK = {
"type": "redacted_thinking",
"data": "EmwKAhgBEgy2VEE8PJaS2oLJCPkaT...",
}
def _make_thinking_transcript() -> str:
"""Build a transcript with thinking blocks in multiple assistant turns.
Layout:
User 1 → Assistant 1 (thinking + text + tool_use)
User 2 (tool_result) → Assistant 2 (thinking + text)
User 3 → Assistant 3 (thinking + redacted_thinking + text) ← LAST
"""
return build_structured_transcript(
[
("user", "What files are in this project?"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "I should list the files.",
"signature": "sig_old_1",
},
{"type": "text", "text": "Let me check the files."},
{
"type": "tool_use",
"id": "tu1",
"name": "list_files",
"input": {"path": "/"},
},
],
),
("user", "Here are the files: a.py, b.py"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "Good, I see two Python files.",
"signature": "sig_old_2",
},
{"type": "text", "text": "I found a.py and b.py."},
],
),
("user", "Tell me about a.py"),
(
"assistant",
[
THINKING_BLOCK,
REDACTED_THINKING_BLOCK,
{"type": "text", "text": "a.py contains the main entry point."},
],
),
]
)
def _last_assistant_content(transcript_jsonl: str) -> list[dict] | None:
"""Extract the content blocks of the last assistant entry in a transcript."""
last_content = None
for line in transcript_jsonl.strip().split("\n"):
entry = json.loads(line)
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_content = msg.get("content")
return last_content
# ---------------------------------------------------------------------------
# _find_last_assistant_entry — unit tests
# ---------------------------------------------------------------------------
class TestFindLastAssistantEntry:
def test_splits_at_last_assistant(self):
"""Prefix contains everything before last assistant; tail starts at it."""
transcript = build_structured_transcript(
[
("user", "Hello"),
("assistant", [{"type": "text", "text": "Hi"}]),
("user", "More"),
("assistant", [{"type": "text", "text": "Details"}]),
]
)
prefix, tail = _find_last_assistant_entry(transcript)
# 3 entries in prefix (user, assistant, user), 1 in tail (last assistant)
assert len(prefix) == 3
assert len(tail) == 1
def test_no_assistant_returns_all_in_prefix(self):
"""When there's no assistant, all lines are in prefix, tail is empty."""
transcript = build_structured_transcript(
[("user", "Hello"), ("user", "Another question")]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 2
assert tail == []
def test_assistant_at_index_zero(self):
"""When assistant is the first entry, prefix is empty."""
transcript = build_structured_transcript(
[("assistant", [{"type": "text", "text": "Start"}])]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert prefix == []
assert len(tail) == 1
def test_trailing_user_included_in_tail(self):
"""User message after last assistant is part of the tail."""
transcript = build_structured_transcript(
[
("user", "Q1"),
("assistant", [{"type": "text", "text": "A1"}]),
("user", "Q2"),
]
)
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 1 # first user
assert len(tail) == 2 # last assistant + trailing user
def test_multi_entry_turn_fully_preserved(self):
"""An assistant turn spanning multiple JSONL entries (same message.id)
must be entirely in the tail, not split across prefix and tail."""
# Build manually because build_structured_transcript generates unique ids
lines = [
json.dumps(
{
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "Hello"},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_same_turn",
"type": "message",
"content": [THINKING_BLOCK],
"stop_reason": None,
"stop_sequence": None,
},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_same_turn",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {},
},
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
),
]
transcript = "\n".join(lines) + "\n"
prefix, tail = _find_last_assistant_entry(transcript)
# Both assistant entries share msg_same_turn → both in tail
assert len(prefix) == 1 # only the user entry
assert len(tail) == 2 # both assistant entries (thinking + tool_use)
def test_no_message_id_preserves_last_assistant(self):
"""When the last assistant entry has no message.id, it should still
be preserved in the tail (fail closed) rather than being compressed."""
lines = [
json.dumps(
{
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "Hello"},
}
),
json.dumps(
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [THINKING_BLOCK, {"type": "text", "text": "Hi"}],
},
}
),
]
transcript = "\n".join(lines) + "\n"
prefix, tail = _find_last_assistant_entry(transcript)
assert len(prefix) == 1 # user entry
assert len(tail) == 1 # assistant entry preserved
# ---------------------------------------------------------------------------
# _rechain_tail — UUID chain patching
# ---------------------------------------------------------------------------
class TestRechainTail:
def test_patches_first_entry_parentuuid(self):
"""First tail entry's parentUuid should point to last prefix uuid."""
prefix = _messages_to_transcript(
[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
)
# Get the last uuid from the prefix
last_prefix_uuid = None
for line in prefix.strip().split("\n"):
entry = json.loads(line)
last_prefix_uuid = entry.get("uuid")
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "tail-a1",
"parentUuid": "old-parent",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "Tail msg"}],
},
}
)
]
result = _rechain_tail(prefix, tail_lines)
entry = json.loads(result.strip())
assert entry["parentUuid"] == last_prefix_uuid
assert entry["uuid"] == "tail-a1" # uuid preserved
def test_chains_multiple_tail_entries(self):
"""Subsequent tail entries chain to each other."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "t1",
"parentUuid": "old1",
"message": {"role": "assistant", "content": []},
}
),
json.dumps(
{
"type": "user",
"uuid": "t2",
"parentUuid": "old2",
"message": {"role": "user", "content": "Follow-up"},
}
),
]
result = _rechain_tail(prefix, tail_lines)
entries = [json.loads(ln) for ln in result.strip().split("\n")]
assert len(entries) == 2
# Second entry's parentUuid should be first entry's uuid
assert entries[1]["parentUuid"] == "t1"
def test_empty_tail_returns_empty(self):
"""No tail entries → empty string."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
assert _rechain_tail(prefix, []) == ""
def test_preserves_message_content_verbatim(self):
"""Tail message content (including thinking blocks) must not be modified."""
prefix = _messages_to_transcript([{"role": "user", "content": "Hi"}])
original_content = [
THINKING_BLOCK,
REDACTED_THINKING_BLOCK,
{"type": "text", "text": "Response"},
]
tail_lines = [
json.dumps(
{
"type": "assistant",
"uuid": "t1",
"parentUuid": "old",
"message": {
"role": "assistant",
"content": original_content,
},
}
)
]
result = _rechain_tail(prefix, tail_lines)
entry = json.loads(result.strip())
assert entry["message"]["content"] == original_content
# ---------------------------------------------------------------------------
# _flatten_assistant_content — thinking blocks
# ---------------------------------------------------------------------------
class TestFlattenThinkingBlocks:
def test_thinking_blocks_are_stripped(self):
"""Thinking blocks should not appear in flattened text for compression."""
blocks = [
{"type": "thinking", "thinking": "secret thoughts", "signature": "sig"},
{"type": "text", "text": "Hello user"},
]
result = _flatten_assistant_content(blocks)
assert "secret thoughts" not in result
assert "Hello user" in result
def test_redacted_thinking_blocks_are_stripped(self):
"""Redacted thinking blocks should not appear in flattened text."""
blocks = [
{"type": "redacted_thinking", "data": "encrypted_data"},
{"type": "text", "text": "Response text"},
]
result = _flatten_assistant_content(blocks)
assert "encrypted_data" not in result
assert "Response text" in result
def test_thinking_only_message_flattens_to_empty(self):
"""A message with only thinking blocks flattens to empty string."""
blocks = [
{"type": "thinking", "thinking": "just thinking...", "signature": "sig"},
]
result = _flatten_assistant_content(blocks)
assert result == ""
def test_mixed_thinking_text_tool(self):
"""Mixed blocks: only text and tool_use survive flattening."""
blocks = [
{"type": "thinking", "thinking": "hmm", "signature": "sig"},
{"type": "redacted_thinking", "data": "xyz"},
{"type": "text", "text": "I'll read the file."},
{"type": "tool_use", "name": "Read", "input": {"path": "/x"}},
]
result = _flatten_assistant_content(blocks)
assert "hmm" not in result
assert "xyz" not in result
assert "I'll read the file." in result
assert "[tool_use: Read]" in result
# ---------------------------------------------------------------------------
# compact_transcript — thinking block preservation
# ---------------------------------------------------------------------------
class TestCompactTranscriptThinkingBlocks:
"""Verify that compact_transcript preserves thinking blocks in the
last assistant message while stripping them from older messages."""
@pytest.mark.asyncio
async def test_last_assistant_thinking_blocks_preserved(self, mock_chat_config):
"""After compaction, the last assistant entry must retain its
original thinking and redacted_thinking blocks verbatim."""
transcript = _make_thinking_transcript()
compacted_msgs = [
{"role": "user", "content": "[conversation summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 800,
"token_count": 200,
"messages_summarized": 4,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
last_content = _last_assistant_content(result)
assert last_content is not None, "No assistant entry found"
assert isinstance(last_content, list)
# The last assistant must have the thinking blocks preserved
block_types = [b["type"] for b in last_content]
assert (
"thinking" in block_types
), "thinking block missing from last assistant message"
assert (
"redacted_thinking" in block_types
), "redacted_thinking block missing from last assistant message"
assert "text" in block_types
# Verify the thinking block content is value-identical
thinking_blocks = [b for b in last_content if b["type"] == "thinking"]
assert len(thinking_blocks) == 1
assert thinking_blocks[0]["thinking"] == THINKING_BLOCK["thinking"]
assert thinking_blocks[0]["signature"] == THINKING_BLOCK["signature"]
redacted_blocks = [b for b in last_content if b["type"] == "redacted_thinking"]
assert len(redacted_blocks) == 1
assert redacted_blocks[0]["data"] == REDACTED_THINKING_BLOCK["data"]
@pytest.mark.asyncio
async def test_older_assistant_thinking_blocks_stripped(self, mock_chat_config):
"""Older assistant messages should NOT retain thinking blocks
after compaction (they're compressed into summaries)."""
transcript = _make_thinking_transcript()
# The compressor will receive messages where older assistant
# entries have already had thinking blocks stripped.
captured_messages: list[dict] = []
async def mock_compression(messages, model, log_prefix):
captured_messages.extend(messages)
return type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": messages,
"original_token_count": 800,
"token_count": 400,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
side_effect=mock_compression,
):
await compact_transcript(transcript, model="test-model")
# Check that the messages sent to compression don't contain
# thinking content from older assistant messages
for msg in captured_messages:
if msg["role"] == "assistant":
content = msg.get("content", "")
assert (
"I should list the files." not in content
), "Old thinking block content leaked into compression input"
assert (
"Good, I see two Python files." not in content
), "Old thinking block content leaked into compression input"
@pytest.mark.asyncio
async def test_trailing_user_message_after_last_assistant(self, mock_chat_config):
"""When the last entry is a user message, the last *assistant*
message's thinking blocks should still be preserved."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
THINKING_BLOCK,
{"type": "text", "text": "Hi there"},
],
),
("user", "Follow-up question"),
]
)
# The compressor only receives the prefix (1 user message); the
# tail (assistant + trailing user) is preserved verbatim.
compacted_msgs = [
{"role": "user", "content": "Hello"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 400,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
last_content = _last_assistant_content(result)
assert last_content is not None
assert isinstance(last_content, list)
block_types = [b["type"] for b in last_content]
assert (
"thinking" in block_types
), "thinking block lost from last assistant despite trailing user msg"
@pytest.mark.asyncio
async def test_single_assistant_with_thinking_preserved(self, mock_chat_config):
"""When there's only one assistant message (which is also the last),
its thinking blocks must be preserved."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
THINKING_BLOCK,
{"type": "text", "text": "World"},
],
),
]
)
compacted_msgs = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "World"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 200,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
last_content = _last_assistant_content(result)
assert last_content is not None
assert isinstance(last_content, list)
block_types = [b["type"] for b in last_content]
assert "thinking" in block_types
@pytest.mark.asyncio
async def test_tail_parentuuid_rewired_to_prefix(self, mock_chat_config):
"""After compaction, the first tail entry's parentUuid must point to
the last entry in the compressed prefix — not its original parent."""
transcript = _make_thinking_transcript()
compacted_msgs = [
{"role": "user", "content": "[conversation summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 800,
"token_count": 200,
"messages_summarized": 4,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
lines = [ln for ln in result.strip().split("\n") if ln.strip()]
entries = [json.loads(ln) for ln in lines]
# Find the boundary: the compressed prefix ends just before the
# first tail entry (last assistant in original transcript).
tail_start = None
for i, entry in enumerate(entries):
msg = entry.get("message", {})
if isinstance(msg.get("content"), list):
# Structured content = preserved tail entry
tail_start = i
break
assert tail_start is not None, "Could not find preserved tail entry"
assert tail_start > 0, "Tail should not be the first entry"
# The tail entry's parentUuid must be the uuid of the preceding entry
prefix_last_uuid = entries[tail_start - 1]["uuid"]
tail_first_parent = entries[tail_start]["parentUuid"]
assert tail_first_parent == prefix_last_uuid, (
f"Tail parentUuid {tail_first_parent!r} != "
f"last prefix uuid {prefix_last_uuid!r}"
)
@pytest.mark.asyncio
async def test_no_thinking_blocks_still_works(self, mock_chat_config):
"""Compaction should still work normally when there are no thinking
blocks in the transcript."""
transcript = build_structured_transcript(
[
("user", "Hello"),
("assistant", [{"type": "text", "text": "Hi"}]),
("user", "More"),
("assistant", [{"type": "text", "text": "Details"}]),
]
)
compacted_msgs = [
{"role": "user", "content": "[summary]"},
{"role": "assistant", "content": "Summary"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 200,
"token_count": 50,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
# Verify last assistant content is preserved even without thinking blocks
last_content = _last_assistant_content(result)
assert last_content is not None
assert last_content == [{"type": "text", "text": "Details"}]
# ---------------------------------------------------------------------------
# _transcript_to_messages — thinking block handling
# ---------------------------------------------------------------------------
class TestTranscriptToMessagesThinking:
def test_thinking_blocks_excluded_from_flattened_content(self):
"""When _transcript_to_messages flattens content, thinking block
text should not leak into the message content string."""
transcript = build_structured_transcript(
[
("user", "Hello"),
(
"assistant",
[
{
"type": "thinking",
"thinking": "SECRET_THOUGHT",
"signature": "sig",
},
{"type": "text", "text": "Visible response"},
],
),
]
)
messages = _transcript_to_messages(transcript)
assistant_msg = [m for m in messages if m["role"] == "assistant"][0]
assert "SECRET_THOUGHT" not in assistant_msg["content"]
assert "Visible response" in assistant_msg["content"]
# ---------------------------------------------------------------------------
# response_adapter — ThinkingBlock handling
# ---------------------------------------------------------------------------
class TestResponseAdapterThinkingBlock:
def test_thinking_block_does_not_crash(self):
"""ThinkingBlock in AssistantMessage should not cause an error."""
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
msg = AssistantMessage(
content=[
ThinkingBlock(
thinking="Let me think about this...",
signature="sig_test_123",
),
TextBlock(text="Here is my response."),
],
model="claude-test",
)
results = adapter.convert_message(msg)
# Should produce stream events for text only, no crash
types = [type(r) for r in results]
assert StreamStartStep in types
assert StreamTextStart in types or StreamTextDelta in types
def test_thinking_block_does_not_emit_stream_events(self):
"""ThinkingBlock should NOT produce any StreamTextDelta events
containing thinking content."""
adapter = SDKResponseAdapter(message_id="msg-1", session_id="sess-1")
msg = AssistantMessage(
content=[
ThinkingBlock(
thinking="My secret thoughts",
signature="sig_test_456",
),
TextBlock(text="Public response"),
],
model="claude-test",
)
results = adapter.convert_message(msg)
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
for delta in text_deltas:
assert "secret thoughts" not in (delta.delta or "")
# ---------------------------------------------------------------------------
# _format_sdk_content_blocks — redacted_thinking handling
# ---------------------------------------------------------------------------
class TestFormatSdkContentBlocks:
def test_thinking_block_preserved(self):
"""ThinkingBlock should be serialized with type, thinking, and signature."""
blocks = [
ThinkingBlock(thinking="My thoughts", signature="sig123"),
TextBlock(text="Response"),
]
result = _format_sdk_content_blocks(blocks)
assert len(result) == 2
assert result[0] == {
"type": "thinking",
"thinking": "My thoughts",
"signature": "sig123",
}
assert result[1] == {"type": "text", "text": "Response"}
def test_raw_dict_redacted_thinking_preserved(self):
"""Raw dict blocks (e.g. redacted_thinking) pass through unchanged."""
raw_block = {"type": "redacted_thinking", "data": "EmwKAh...encrypted"}
blocks = [
raw_block,
TextBlock(text="Response"),
]
result = _format_sdk_content_blocks(blocks)
assert len(result) == 2
assert result[0] == raw_block
assert result[1] == {"type": "text", "text": "Response"}

View File

@@ -605,20 +605,31 @@ COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
placeholders. This is intentional: ``compress_context`` requires plain
text for token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript was
already too large for the model — a summarized plain-text version is
better than no context at all.
placeholders. ``thinking`` and ``redacted_thinking`` blocks are
silently dropped — they carry no useful context for compression
summaries and must not leak into compacted transcripts (the Anthropic
API requires thinking blocks in the last assistant message to be
value-identical to the original response; including stale thinking
text would violate that constraint).
This is intentional: ``compress_context`` requires plain text for
token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript
was already too large for the model.
"""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
btype = block.get("type", "")
if btype in _THINKING_BLOCK_TYPES:
continue
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_use":
@@ -805,6 +816,68 @@ async def _run_compression(
)
def _find_last_assistant_entry(
content: str,
) -> tuple[list[str], list[str]]:
"""Split JSONL lines into (compressible_prefix, preserved_tail).
The tail starts at the **first** entry of the last assistant turn and
includes everything after it (typically trailing user messages). An
assistant turn can span multiple consecutive JSONL entries sharing the
same ``message.id`` (e.g., a thinking entry followed by a tool_use
entry). All entries of the turn are preserved verbatim.
The Anthropic API requires that ``thinking`` and ``redacted_thinking``
blocks in the **last** assistant message remain value-identical to the
original response (the API validates parsed signature values, not raw
JSON bytes). By excluding the entire turn from compression we
guarantee those blocks are never altered.
Returns ``(all_lines, [])`` when no assistant entry is found.
"""
lines = [ln for ln in content.strip().split("\n") if ln.strip()]
# Parse all lines once to avoid double JSON deserialization.
# json.loads with fallback=None returns Any; non-dict entries are
# safely skipped by the isinstance(entry, dict) guards below.
parsed: list = [json.loads(ln, fallback=None) for ln in lines]
# Reverse scan: find the message.id and index of the last assistant entry.
last_asst_msg_id: str | None = None
last_asst_idx: int | None = None
for i in range(len(parsed) - 1, -1, -1):
entry = parsed[i]
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant":
last_asst_idx = i
last_asst_msg_id = msg.get("id")
break
if last_asst_idx is None:
return lines, []
# If the assistant entry has no message.id, fall back to preserving
# from that single entry onward — safer than compressing everything.
if last_asst_msg_id is None:
return lines[:last_asst_idx], lines[last_asst_idx:]
# Forward scan: find the first entry of this turn (same message.id).
first_turn_idx: int | None = None
for i, entry in enumerate(parsed):
if not isinstance(entry, dict):
continue
msg = entry.get("message", {})
if msg.get("role") == "assistant" and msg.get("id") == last_asst_msg_id:
first_turn_idx = i
break
if first_turn_idx is None:
return lines, []
return lines[:first_turn_idx], lines[first_turn_idx:]
async def compact_transcript(
content: str,
*,
@@ -816,42 +889,50 @@ async def compact_transcript(
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
is flattened to plain text for compression. This matches the fidelity of
the Plan C (DB compression) fallback path, where
``_format_conversation_context`` similarly renders tool calls as
``You called tool: name(args)`` and results as ``Tool result: ...``.
Neither path preserves structured API content blocks — the compacted
context serves as text history for the LLM, which creates proper
structured tool calls going forward.
The **last assistant entry** (and any entries after it) are preserved
verbatim — never flattened or compressed. The Anthropic API requires
``thinking`` and ``redacted_thinking`` blocks in the latest assistant
message to be value-identical to the original response (the API
validates parsed signature values, not raw JSON bytes); compressing
them would destroy the cryptographic signatures and cause
``invalid_request_error``.
Images are per-turn attachments loaded from workspace storage by file ID
(via ``_prepare_file_attachments``), not part of the conversation history.
They are re-attached each turn and are unaffected by compaction.
Structured content in *older* assistant entries (``tool_use`` blocks,
``thinking`` blocks, ``tool_result`` nesting, images) is flattened to
plain text for compression. This matches the fidelity of the Plan C
(DB compression) fallback path.
Returns the compacted JSONL string, or ``None`` on failure.
See also:
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
lists for pre-query DB history. Both share ``compress_context()``
but operate on different input formats (JSONL transcript entries
here vs. ChatMessage dicts there).
lists for pre-query DB history.
"""
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
prefix_lines, tail_lines = _find_last_assistant_entry(content)
# Build the JSONL string for the compressible prefix
prefix_content = "\n".join(prefix_lines) + "\n" if prefix_lines else ""
messages = _transcript_to_messages(prefix_content) if prefix_content else []
if len(messages) + len(tail_lines) < 2:
total = len(messages) + len(tail_lines)
logger.warning("%s Too few messages to compact (%d)", log_prefix, total)
return None
if not messages:
logger.warning("%s Nothing to compress (only tail entries remain)", log_prefix)
return None
try:
result = await _run_compression(messages, model, log_prefix)
if not result.was_compacted:
# Compressor says it's within budget, but the SDK rejected it.
# Return None so the caller falls through to DB fallback.
logger.warning(
"%s Compressor reports within budget but SDK rejected — "
"signalling failure",
log_prefix,
)
return None
if not result.messages:
logger.warning("%s Compressor returned empty messages", log_prefix)
return None
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
@@ -860,7 +941,29 @@ async def compact_transcript(
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
compressed_part = _messages_to_transcript(result.messages)
# Re-append the preserved tail (last assistant + trailing entries)
# with parentUuid patched to chain onto the compressed prefix.
tail_part = _rechain_tail(compressed_part, tail_lines)
compacted = compressed_part + tail_part
if len(compacted) >= len(content):
# Byte count can increase due to preserved tail entries
# (thinking blocks, JSON overhead) even when token count
# decreased. Log a warning but still return — the API
# validates tokens not bytes, and the caller falls through
# to DB fallback if the transcript is still too large.
logger.warning(
"%s Compacted transcript (%d bytes) is not smaller than "
"original (%d bytes) — may still reduce token count",
log_prefix,
len(compacted),
len(content),
)
# Authoritative validation — the caller (_reduce_context) also
# validates, but this is the canonical check that guarantees we
# never return a malformed transcript from this function.
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
@@ -870,3 +973,43 @@ async def compact_transcript(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None
def _rechain_tail(compressed_prefix: str, tail_lines: list[str]) -> str:
"""Patch tail entries so their parentUuid chain links to the compressed prefix.
The first tail entry's ``parentUuid`` is set to the ``uuid`` of the
last entry in the compressed prefix. Subsequent tail entries are
rechained to point to their predecessor in the tail — their original
``parentUuid`` values may reference entries that were compressed away.
"""
if not tail_lines:
return ""
# Find the last uuid in the compressed prefix
last_prefix_uuid = ""
for line in reversed(compressed_prefix.strip().split("\n")):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if isinstance(entry, dict) and "uuid" in entry:
last_prefix_uuid = entry["uuid"]
break
result_lines: list[str] = []
prev_uuid: str | None = None
for i, line in enumerate(tail_lines):
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
# Safety guard: _find_last_assistant_entry already filters empty
# lines, and well-formed JSONL always parses to dicts. Non-dict
# lines are passed through unchanged; prev_uuid is intentionally
# NOT updated so the next dict entry chains to the last known uuid.
result_lines.append(line)
continue
if i == 0:
entry["parentUuid"] = last_prefix_uuid
elif prev_uuid is not None:
entry["parentUuid"] = prev_uuid
prev_uuid = entry.get("uuid")
result_lines.append(json.dumps(entry, separators=(",", ":")))
return "\n".join(result_lines) + "\n"

View File

@@ -17,13 +17,16 @@ Subscribers:
import asyncio
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from redis.exceptions import RedisError
from backend.api.model import CopilotCompletionPayload
from backend.data.db_accessors import chat_db
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
@@ -33,12 +36,21 @@ from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS
from .response_model import (
ResponseType,
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
logger = logging.getLogger(__name__)
@@ -100,6 +112,14 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
``session_id`` is used as a fallback for ``turn_id`` when the meta hash
pre-dates the turn_id field (backward compat for in-flight sessions).
"""
created_at = datetime.now(timezone.utc)
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
except (ValueError, TypeError):
pass
return ActiveSession(
session_id=meta.get("session_id", "") or session_id,
user_id=meta.get("user_id", "") or None,
@@ -108,6 +128,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
turn_id=meta.get("turn_id", "") or session_id,
blocking=meta.get("blocking") == "1",
status=meta.get("status", "running"), # type: ignore[arg-type]
created_at=created_at,
)
@@ -280,6 +301,56 @@ async def publish_chunk(
return message_id
async def stream_and_publish(
session_id: str,
turn_id: str,
stream: AsyncIterator[StreamBaseResponse],
) -> AsyncIterator[StreamBaseResponse]:
"""Wrap an async stream iterator with registry publishing.
Publishes each chunk to the stream registry for frontend SSE consumption,
skipping ``StreamFinish`` and ``StreamError`` (which are published by
:func:`mark_session_completed`).
This is a pass-through: every event from *stream* is yielded unchanged so
the caller can still consume and aggregate them. The caller is responsible
for calling :func:`create_session` before and :func:`mark_session_completed`
after iterating.
Args:
session_id: Chat session ID (for logging only).
turn_id: Turn UUID that identifies the Redis stream to publish to.
If empty, publishing is silently skipped (graceful degradation).
stream: The underlying async iterator of stream events.
Yields:
Every event from *stream*, unchanged.
"""
publish_failed_once = False
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
try:
await publish_chunk(turn_id, event)
except (RedisError, ConnectionError, OSError):
if not publish_failed_once:
publish_failed_once = True
logger.warning(
"[stream_and_publish] Failed to publish chunk %s for %s "
"(further failures logged at DEBUG)",
type(event).__name__,
session_id[:12],
exc_info=True,
)
else:
logger.debug(
"[stream_and_publish] Failed to publish chunk %s",
type(event).__name__,
exc_info=True,
)
yield event
async def subscribe_to_session(
session_id: str,
user_id: str | None,
@@ -693,6 +764,8 @@ async def _stream_listener(
async def mark_session_completed(
session_id: str,
error_message: str | None = None,
*,
skip_error_publish: bool = False,
) -> bool:
"""Mark a session as completed, then publish StreamFinish.
@@ -708,6 +781,10 @@ async def mark_session_completed(
session_id: Session ID to mark as completed
error_message: If provided, marks as "failed" and publishes a
StreamError before StreamFinish. Otherwise marks as "completed".
skip_error_publish: If True, still marks the session as "failed" but
does NOT publish a StreamError event. Use this when the error has
already been published to the stream (e.g. via stream_and_publish)
to avoid duplicate error delivery to the frontend.
Returns:
True if session was newly marked completed, False if already completed/failed
@@ -727,7 +804,7 @@ async def mark_session_completed(
logger.debug(f"Session {session_id} already completed/failed, skipping")
return False
if error_message:
if error_message and not skip_error_publish:
try:
await publish_chunk(turn_id, StreamError(errorText=error_message))
except Exception as e:
@@ -735,6 +812,33 @@ async def mark_session_completed(
f"Failed to publish error event for session {session_id}: {e}"
)
# Compute wall-clock duration from session created_at.
# Only persist when (a) the session completed successfully and
# (b) created_at was actually present in Redis meta (not a fallback).
duration_ms: int | None = None
if meta and not error_message:
created_at_raw = meta.get("created_at")
if created_at_raw:
try:
created_at = datetime.fromisoformat(str(created_at_raw))
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
elapsed = datetime.now(timezone.utc) - created_at
duration_ms = max(0, int(elapsed.total_seconds() * 1000))
except (ValueError, TypeError):
logger.warning(
"Failed to compute session duration for %s (created_at=%r)",
session_id,
created_at_raw,
)
# Persist duration on the last assistant message
if duration_ms is not None:
try:
await chat_db().set_turn_duration(session_id, duration_ms)
except Exception as e:
logger.warning(f"Failed to save turn duration for {session_id}: {e}")
# Publish StreamFinish AFTER status is set to "completed"/"failed".
# This is the SINGLE place that publishes StreamFinish — services and
# the processor must NOT publish it themselves.
@@ -913,21 +1017,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextEnd,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
# Map response types to their corresponding classes
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,

View File

@@ -102,7 +102,6 @@ async def setup_test_data(server):
"value": "",
"advanced": False,
"description": "Test input field",
"placeholder_values": [],
},
metadata={"position": {"x": 0, "y": 0}},
)
@@ -242,7 +241,6 @@ async def setup_llm_test_data(server):
"value": "",
"advanced": False,
"description": "Prompt for the LLM",
"placeholder_values": [],
},
metadata={"position": {"x": 0, "y": 0}},
)
@@ -396,7 +394,6 @@ async def setup_firecrawl_test_data(server):
"value": "",
"advanced": False,
"description": "URL for Firecrawl to scrape",
"placeholder_values": [],
},
metadata={"position": {"x": 0, "y": 0}},
)

View File

@@ -4,6 +4,8 @@ import logging
import re
from typing import Any
from backend.data.dynamic_fields import DICT_SPLIT
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
@@ -1536,8 +1538,8 @@ class AgentFixer:
for link in links:
sink_name = link.get("sink_name", "")
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
if DICT_SPLIT in sink_name:
parent, child = sink_name.split(DICT_SPLIT, 1)
# Check if child is a numeric index (invalid for _#_ notation)
if child.isdigit():

View File

@@ -4,6 +4,8 @@ import re
import uuid
from typing import Any
from backend.data.dynamic_fields import DICT_SPLIT
from .blocks import get_blocks_as_dicts
__all__ = [
@@ -51,8 +53,8 @@ def generate_uuid() -> str:
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
"""Get property type from a schema, handling nested `_#_` notation."""
if "_#_" in name:
parent, child = name.split("_#_", 1)
if DICT_SPLIT in name:
parent, child = name.split(DICT_SPLIT, 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema and isinstance(
parent_schema["properties"], dict

View File

@@ -5,6 +5,8 @@ import logging
import re
from typing import Any
from backend.data.dynamic_fields import DICT_SPLIT
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
@@ -256,95 +258,6 @@ class AgentValidator:
return valid
def validate_nested_sink_links(
self,
agent: AgentDict,
blocks: list[dict[str, Any]],
node_lookup: dict[str, dict[str, Any]] | None = None,
) -> bool:
"""
Validate nested sink links (links with _#_ notation).
Returns True if all nested links are valid, False otherwise.
"""
valid = True
block_input_schemas = {
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
if node_lookup is None:
node_lookup = self._build_node_lookup(agent)
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
sink_id = link.get("sink_id")
if not sink_name or not sink_id:
continue
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(sink_id)
if not sink_node:
continue
block_id = sink_node.get("block_id")
input_props = block_input_schemas.get(block_id, {})
parent_schema = input_props.get(parent)
if not parent_schema:
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' for "
f"node '{sink_id}' (block "
f"'{block_name}' - {block_id}): Parent property "
f"'{parent}' does not exist in the block's "
f"input schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
# Check anyOf for additionalProperties
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' "
f"for node '{link.get('sink_id', '')}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' schema. Available "
f"properties: "
f"{list(parent_schema.get('properties', {}).keys())}"
)
valid = False
return valid
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
"""
Validate that prompt parameters do not contain spaces in double curly
@@ -471,8 +384,8 @@ class AgentValidator:
output_props = block_output_schemas.get(block_id, {})
# Handle nested source names (with _#_ notation)
if "_#_" in source_name:
parent, child = source_name.split("_#_", 1)
if DICT_SPLIT in source_name:
parent, child = source_name.split(DICT_SPLIT, 1)
parent_schema = output_props.get(parent)
if not parent_schema:
@@ -553,6 +466,195 @@ class AgentValidator:
return valid
def validate_sink_input_existence(
self,
agent: AgentDict,
blocks: list[dict[str, Any]],
node_lookup: dict[str, dict[str, Any]] | None = None,
) -> bool:
"""
Validate that all sink_names in links and input_default keys in nodes
exist in the corresponding block's input schema.
Checks that for each link the sink_name references a valid input
property in the sink block's inputSchema, and that every key in a
node's input_default is a recognised input property. Also handles
nested inputs with _#_ notation and dynamic schemas for
AgentExecutorBlock.
Args:
agent: The agent dictionary to validate
blocks: List of available blocks with their schemas
node_lookup: Optional pre-built node-id → node dict
Returns:
True if all sink input fields exist, False otherwise
"""
valid = True
block_input_schemas = {
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
if node_lookup is None:
node_lookup = self._build_node_lookup(agent)
def get_input_props(node: dict[str, Any]) -> dict[str, Any]:
block_id = node.get("block_id", "")
if block_id == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default", {})
dynamic_input_schema = input_default.get("input_schema", {})
if not isinstance(dynamic_input_schema, dict):
dynamic_input_schema = {}
dynamic_props = dynamic_input_schema.get("properties", {})
if not isinstance(dynamic_props, dict):
dynamic_props = {}
static_props = block_input_schemas.get(block_id, {})
return {**static_props, **dynamic_props}
return block_input_schemas.get(block_id, {})
def check_nested_input(
input_props: dict[str, Any],
field_name: str,
context: str,
block_name: str,
block_id: str,
) -> bool:
parent, child = field_name.split(DICT_SPLIT, 1)
parent_schema = input_props.get(parent)
if not parent_schema:
self.add_error(
f"{context}: Parent property '{parent}' does not "
f"exist in block '{block_name}' ({block_id}) input "
f"schema."
)
return False
allows_additional = parent_schema.get("additionalProperties", False)
# Only anyOf is checked here because Pydantic's JSON schema
# emits optional/union fields via anyOf. allOf and oneOf are
# not currently used by any block's dict-typed inputs, so
# false positives from them are not a concern in practice.
if not allows_additional and "anyOf" in parent_schema:
for schema_option in parent_schema.get("anyOf", []):
if not isinstance(schema_option, dict):
continue
if schema_option.get("additionalProperties"):
allows_additional = True
break
items_schema = schema_option.get("items")
if isinstance(items_schema, dict) and items_schema.get(
"additionalProperties"
):
allows_additional = True
break
if not allows_additional:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
available = (
list(parent_schema.get("properties", {}).keys())
if isinstance(parent_schema, dict)
else []
)
self.add_error(
f"{context}: Child property '{child}' does not "
f"exist in parent '{parent}' of block "
f"'{block_name}' ({block_id}) input schema. "
f"Available properties: {available}"
)
return False
return True
for link in agent.get("links", []):
sink_id = link.get("sink_id")
sink_name = link.get("sink_name", "")
link_id = link.get("id", "Unknown")
if not sink_name:
# Missing sink_name is caught by validate_data_type_compatibility
continue
sink_node = node_lookup.get(sink_id)
if not sink_node:
# Already caught by validate_link_node_references
continue
block_id = sink_node.get("block_id", "")
block_name = block_names.get(block_id, "Unknown Block")
input_props = get_input_props(sink_node)
context = (
f"Invalid sink input field '{sink_name}' in link "
f"'{link_id}' to node '{sink_id}'"
)
if DICT_SPLIT in sink_name:
if not check_nested_input(
input_props, sink_name, context, block_name, block_id
):
valid = False
else:
if sink_name not in input_props:
available_inputs = list(input_props.keys())
self.add_error(
f"{context} (block '{block_name}' - {block_id}): "
f"Input property '{sink_name}' does not exist in "
f"the block's input schema. "
f"Available inputs: {available_inputs}"
)
valid = False
for node in agent.get("nodes", []):
node_id = node.get("id")
block_id = node.get("block_id", "")
block_name = block_names.get(block_id, "Unknown Block")
input_default = node.get("input_default", {})
if not isinstance(input_default, dict) or not input_default:
continue
if (
block_id not in block_input_schemas
and block_id != AGENT_EXECUTOR_BLOCK_ID
):
continue
input_props = get_input_props(node)
for key in input_default:
if key == "credentials":
continue
context = (
f"Node '{node_id}' (block '{block_name}' - {block_id}) "
f"has unknown input_default key '{key}'"
)
if DICT_SPLIT in key:
if not check_nested_input(
input_props, key, context, block_name, block_id
):
valid = False
else:
if key not in input_props:
available_inputs = list(input_props.keys())
self.add_error(
f"{context} which does not exist in the "
f"block's input schema. "
f"Available inputs: {available_inputs}"
)
valid = False
return valid
def validate_io_blocks(self, agent: AgentDict) -> bool:
"""
Validate that the agent has at least one AgentInputBlock and one
@@ -998,14 +1100,14 @@ class AgentValidator:
"Data type compatibility",
self.validate_data_type_compatibility(agent, blocks, node_lookup),
),
(
"Nested sink links",
self.validate_nested_sink_links(agent, blocks, node_lookup),
),
(
"Source output existence",
self.validate_source_output_existence(agent, blocks, node_lookup),
),
(
"Sink input existence",
self.validate_sink_input_existence(agent, blocks, node_lookup),
),
(
"Prompt double curly braces spaces",
self.validate_prompt_double_curly_braces_spaces(agent),

View File

@@ -331,43 +331,6 @@ class TestValidatePromptDoubleCurlyBracesSpaces:
assert any("spaces" in e for e in v.errors)
# ============================================================================
# validate_nested_sink_links
# ============================================================================
class TestValidateNestedSinkLinks:
def test_valid_nested_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is True
def test_invalid_parent_fails(self):
v = AgentValidator()
block = _make_block(block_id="b1")
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_block_schemas
# ============================================================================
@@ -595,11 +558,28 @@ class TestValidate:
input_block = _make_block(
block_id=AGENT_INPUT_BLOCK_ID,
name="AgentInputBlock",
input_schema={
"properties": {
"name": {"type": "string"},
"title": {"type": "string"},
"value": {},
"description": {"type": "string"},
},
"required": ["name"],
},
output_schema={"properties": {"result": {}}},
)
output_block = _make_block(
block_id=AGENT_OUTPUT_BLOCK_ID,
name="AgentOutputBlock",
input_schema={
"properties": {
"name": {"type": "string"},
"title": {"type": "string"},
"value": {},
},
"required": ["name"],
},
)
input_node = _make_node(
node_id="n-in",
@@ -650,6 +630,201 @@ class TestValidate:
assert "AgentOutputBlock" in error_message
class TestValidateSinkInputExistence:
"""Tests for validate_sink_input_existence."""
def test_valid_sink_name_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src", source_name="out", sink_id="n1", sink_name="url"
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_invalid_sink_name_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src", source_name="out", sink_id="n1", sink_name="nonexistent"
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is False
assert any("nonexistent" in e for e in v.errors)
def test_valid_nested_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src",
source_name="out",
sink_id="n1",
sink_name="config_#_key",
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_invalid_nested_child_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(
source_id="src",
source_name="out",
sink_id="n1",
sink_name="config_#_missing",
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is False
def test_unknown_input_default_key_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(
node_id="n1", block_id="b1", input_default={"nonexistent_key": "value"}
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is False
assert any("nonexistent_key" in e for e in v.errors)
def test_credentials_key_skipped(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={"properties": {"url": {"type": "string"}}, "required": []},
)
node = _make_node(
node_id="n1",
block_id="b1",
input_default={
"url": "http://example.com",
"credentials": {"api_key": "x"},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_agent_executor_dynamic_schema_passes(self):
v = AgentValidator()
block = _make_block(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_schema={
"properties": {
"graph_id": {"type": "string"},
"input_schema": {"type": "object"},
},
"required": ["graph_id"],
},
)
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": "abc",
"input_schema": {
"properties": {"query": {"type": "string"}},
"required": [],
},
},
)
link = _make_link(
source_id="src",
source_name="out",
sink_id="n1",
sink_name="query",
)
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_sink_input_existence(agent, [block]) is True
def test_input_default_nested_invalid_child_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(
node_id="n1",
block_id="b1",
input_default={"config_#_invalid_child": "value"},
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is False
assert any("invalid_child" in e for e in v.errors)
def test_input_default_nested_valid_child_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(
node_id="n1",
block_id="b1",
input_default={"config_#_key": "value"},
)
agent = _make_agent(nodes=[node])
assert v.validate_sink_input_existence(agent, [block]) is True
class TestValidateMCPToolBlocks:
"""Tests for validate_mcp_tool_blocks."""

View File

@@ -10,7 +10,12 @@ from pydantic import BaseModel, Field, field_validator
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from backend.data.execution import (
ExecutionStatus,
GraphExecution,
GraphExecutionMeta,
GraphExecutionWithNodes,
)
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
@@ -35,6 +40,7 @@ class AgentOutputInput(BaseModel):
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
show_execution_details: bool = False
@field_validator(
"agent_name",
@@ -146,6 +152,10 @@ class AgentOutputTool(BaseTool):
"minimum": 0,
"maximum": 300,
},
"show_execution_details": {
"type": "boolean",
"description": "If true, include full node-by-node execution trace (inputs, outputs, status, timing for each node). Useful for debugging agent wiring. Default: false.",
},
},
"required": [],
}
@@ -226,13 +236,19 @@ class AgentOutputTool(BaseTool):
time_start: datetime | None,
time_end: datetime | None,
include_running: bool = False,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
include_node_executions: bool = False,
) -> tuple[
GraphExecution | GraphExecutionWithNodes | None,
list[GraphExecutionMeta],
str | None,
]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
Args:
include_running: If True, also look for running/queued executions (for waiting)
include_node_executions: If True, include node-by-node execution details
"""
exec_db = execution_db()
@@ -241,7 +257,7 @@ class AgentOutputTool(BaseTool):
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
include_node_executions=include_node_executions,
)
if not execution:
return None, [], f"Execution '{execution_id}' not found"
@@ -279,7 +295,7 @@ class AgentOutputTool(BaseTool):
full_execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
include_node_executions=include_node_executions,
)
return full_execution, [], None
@@ -287,14 +303,14 @@ class AgentOutputTool(BaseTool):
full_execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
include_node_executions=include_node_executions,
)
return full_execution, executions, None
def _build_response(
self,
agent: LibraryAgent,
execution: GraphExecution | None,
execution: GraphExecution | GraphExecutionWithNodes | None,
available_executions: list[GraphExecutionMeta],
session_id: str | None,
) -> AgentOutputResponse:
@@ -312,6 +328,21 @@ class AgentOutputTool(BaseTool):
total_executions=0,
)
node_executions_data = None
if isinstance(execution, GraphExecutionWithNodes):
node_executions_data = [
{
"node_id": ne.node_id,
"block_id": ne.block_id,
"status": ne.status.value,
"input_data": ne.input_data,
"output_data": dict(ne.output_data),
"start_time": ne.start_time.isoformat() if ne.start_time else None,
"end_time": ne.end_time.isoformat() if ne.end_time else None,
}
for ne in execution.node_executions
]
execution_info = ExecutionOutputInfo(
execution_id=execution.id,
status=execution.status.value,
@@ -319,6 +350,7 @@ class AgentOutputTool(BaseTool):
ended_at=execution.ended_at,
outputs=dict(execution.outputs),
inputs_summary=execution.inputs if execution.inputs else None,
node_executions=node_executions_data,
)
available_list = None
@@ -428,7 +460,7 @@ class AgentOutputTool(BaseTool):
execution = await execution_db().get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
include_node_executions=input_data.show_execution_details,
)
if not execution:
return ErrorResponse(
@@ -484,6 +516,7 @@ class AgentOutputTool(BaseTool):
time_start=time_start,
time_end=time_end,
include_running=wait_timeout > 0,
include_node_executions=input_data.show_execution_details,
)
if exec_error:

View File

@@ -119,8 +119,11 @@ class ContinueRunBlockTool(BaseTool):
)
logger.info(
f"Continuing block {block.name} ({block_id}) for user {user_id} "
f"with review_id={review_id}"
"Continuing block %s (%s) for user %s with review_id=%s",
block.name,
block_id,
user_id,
review_id,
)
matched_creds, missing_creds = await resolve_block_credentials(
@@ -132,6 +135,9 @@ class ContinueRunBlockTool(BaseTool):
session_id=session_id,
)
# dry_run=False is safe here: run_block's dry-run fast-path (line ~241)
# skips HITL entirely, so continue_run_block is never called during a
# dry run — only real executions reach the human review gate.
result = await execute_block(
block=block,
block_id=block_id,
@@ -140,6 +146,7 @@ class ContinueRunBlockTool(BaseTool):
session_id=session_id,
node_exec_id=review_id,
matched_credentials=matched_creds,
dry_run=False,
)
# Delete review record after successful execution (one-time use)

View File

@@ -21,6 +21,7 @@ from backend.data.credit import UsageTransactionMetadata
from backend.data.db_accessors import credit_db, review_db, workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.simulator import simulate_block
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError, InsufficientBalanceError
@@ -80,6 +81,7 @@ async def execute_block(
node_exec_id: str,
matched_credentials: dict[str, CredentialsMetaInput],
sensitive_action_safe_mode: bool = False,
dry_run: bool = False,
) -> ToolResponseBase:
"""Execute a block with full context setup, credential injection, and error handling.
@@ -89,6 +91,49 @@ async def execute_block(
Returns:
BlockOutputResponse on success, ErrorResponse on failure.
"""
# Dry-run path: simulate the block with an LLM, no real execution.
# HITL review is intentionally skipped — no real execution occurs.
if dry_run:
try:
# Coerce types to match the block's input schema, same as real execution.
# This ensures the simulated preview is consistent with real execution
# (e.g., "42" → 42, string booleans → bool, enum defaults applied).
coerce_inputs_to_schema(input_data, block.input_schema)
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in simulate_block(block, input_data):
outputs[output_name].append(output_data)
# simulator signals internal failure via ("error", "[SIMULATOR ERROR …]")
sim_error = outputs.get("error", [])
if (
sim_error
and isinstance(sim_error[0], str)
and sim_error[0].startswith("[SIMULATOR ERROR")
):
return ErrorResponse(
message=sim_error[0],
error=sim_error[0],
session_id=session_id,
)
return BlockOutputResponse(
message=(
f"[DRY RUN] Block '{block.name}' simulated successfully "
"— no real execution occurred."
),
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
is_dry_run=True,
session_id=session_id,
)
except Exception as e:
logger.error("Dry-run simulation failed: %s", e, exc_info=True)
return ErrorResponse(
message=f"Dry-run simulation failed: {e}",
error=str(e),
session_id=session_id,
)
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
@@ -292,6 +337,7 @@ async def prepare_block_for_execution(
user_id: str,
session: ChatSession,
session_id: str,
dry_run: bool = False,
) -> "BlockPreparation | ToolResponseBase":
"""Validate and prepare a block for execution.
@@ -379,7 +425,7 @@ async def prepare_block_for_execution(
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
if missing_credentials:
if missing_credentials and not dry_run:
credentials_fields_info = _resolve_discriminated_credentials(block, input_data)
missing_creds_dict = build_missing_credentials_from_field_info(
credentials_fields_info, set(matched_credentials.keys())
@@ -491,7 +537,7 @@ async def check_hitl_review(
)
synthetic_node_exec_id = (
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}" f"{uuid.uuid4().hex[:8]}"
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}{uuid.uuid4().hex[:8]}"
)
review_context = ExecutionContext(
@@ -536,7 +582,16 @@ def _resolve_discriminated_credentials(
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
"""Resolve credential requirements, applying discriminator logic where needed.
Handles two discrimination modes:
1. **Provider-based** (``discriminator_mapping`` is set): the discriminator
field value selects the provider (e.g. an AI model name -> provider).
2. **URL/host-based** (``discriminator`` is set but ``discriminator_mapping``
is ``None``): the discriminator field value (typically a URL) is added to
``discriminator_values`` so that host-scoped credential matching can
compare the credential's host against the target URL.
"""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
@@ -546,23 +601,42 @@ def _resolve_discriminated_credentials(
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
if field_info.discriminator:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(field_info.discriminator)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
if discriminator_value is not None:
if field_info.discriminator_mapping:
# Provider-based discrimination (e.g. model -> provider)
if discriminator_value in field_info.discriminator_mapping:
effective_field_info = field_info.discriminate(
discriminator_value
)
effective_field_info.discriminator_values.add(
discriminator_value
)
# Model names are safe to log (not PII); URLs are
# intentionally omitted in the host-based branch below.
logger.debug(
"Discriminated provider for %s: %s -> %s",
field_name,
discriminator_value,
effective_field_info.provider,
)
else:
# URL/host-based discrimination (e.g. url -> host matching).
# Deep copy to avoid mutating the cached schema-level
# field_info (model_copy() is shallow — the mutable set
# would be shared).
effective_field_info = field_info.model_copy(deep=True)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
"Added discriminator value for host matching on %s",
field_name,
)
resolved[field_name] = effective_field_info

View File

@@ -0,0 +1,916 @@
"""Tests for credential resolution across all credential types in the CoPilot.
These tests verify that:
1. `_resolve_discriminated_credentials` correctly populates discriminator_values
for URL-based (host-scoped) and provider-based (api_key) credential fields.
2. `find_matching_credential` correctly matches credentials for all types:
APIKeyCredentials, OAuth2Credentials, UserPasswordCredentials, and
HostScopedCredentials.
3. The full `resolve_block_credentials` flow correctly resolves matching
credentials or reports them as missing for each credential type.
4. `RunBlockTool._execute` end-to-end tests return correct response types.
"""
from unittest.mock import AsyncMock, patch
from pydantic import SecretStr
from backend.blocks.http import SendAuthenticatedWebRequestBlock
from backend.data.model import (
APIKeyCredentials,
CredentialsFieldInfo,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from ._test_data import make_session
from .helpers import _resolve_discriminated_credentials, resolve_block_credentials
from .models import BlockDetailsResponse, SetupRequirementsResponse
from .run_block import RunBlockTool
from .utils import find_matching_credential
_TEST_USER_ID = "test-user-http-cred"
# Properly typed constants to avoid type: ignore on CredentialsFieldInfo construction.
_HOST_SCOPED_TYPES: frozenset[CredentialsType] = frozenset(["host_scoped"])
_API_KEY_TYPES: frozenset[CredentialsType] = frozenset(["api_key"])
_OAUTH2_TYPES: frozenset[CredentialsType] = frozenset(["oauth2"])
_USER_PASSWORD_TYPES: frozenset[CredentialsType] = frozenset(["user_password"])
# ---------------------------------------------------------------------------
# _resolve_discriminated_credentials tests
# ---------------------------------------------------------------------------
class TestResolveDiscriminatedCredentials:
"""Tests for _resolve_discriminated_credentials with URL-based discrimination."""
def _get_auth_block(self):
return SendAuthenticatedWebRequestBlock()
def test_url_discriminator_populates_discriminator_values(self):
"""When input_data contains a URL, discriminator_values should include it."""
block = self._get_auth_block()
input_data = {"url": "https://api.example.com/v1/data"}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
assert "https://api.example.com/v1/data" in field_info.discriminator_values
def test_url_discriminator_without_url_keeps_empty_values(self):
"""When no URL is provided, discriminator_values should remain empty."""
block = self._get_auth_block()
input_data = {}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
assert len(field_info.discriminator_values) == 0
def test_url_discriminator_does_not_mutate_original_field_info(self):
"""The original block schema field_info must not be mutated."""
block = self._get_auth_block()
# Grab a reference to the original schema-level field_info
original_info = block.input_schema.get_credentials_fields_info()["credentials"]
# Call with a URL, which adds to discriminator_values on the copy
_resolve_discriminated_credentials(
block, {"url": "https://api.example.com/v1/data"}
)
# The original object must remain unchanged
assert len(original_info.discriminator_values) == 0
# And a fresh call without URL should also return empty values
result = _resolve_discriminated_credentials(block, {})
field_info = result["credentials"]
assert len(field_info.discriminator_values) == 0
def test_url_discriminator_preserves_provider_and_type(self):
"""Provider and supported_types should be preserved after URL discrimination."""
block = self._get_auth_block()
input_data = {"url": "https://api.example.com/v1/data"}
result = _resolve_discriminated_credentials(block, input_data)
field_info = result["credentials"]
assert ProviderName.HTTP in field_info.provider
assert "host_scoped" in field_info.supported_types
def test_provider_discriminator_still_works(self):
"""Verify provider-based discrimination (e.g. model -> provider) is preserved.
The refactored conditional in _resolve_discriminated_credentials split the
original single ``if`` into nested ``if/else`` branches. This test ensures
the provider-based path still narrows the provider correctly.
"""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
result = _resolve_discriminated_credentials(block, input_data)
assert "credentials" in result
field_info = result["credentials"]
# Should narrow provider to openai
assert ProviderName.OPENAI in field_info.provider
assert "gpt-4o-mini" in field_info.discriminator_values
# ---------------------------------------------------------------------------
# find_matching_credential tests (host-scoped)
# ---------------------------------------------------------------------------
class TestFindMatchingHostScopedCredential:
"""Tests for find_matching_credential with host-scoped credentials."""
def _make_host_scoped_cred(
self, host: str, cred_id: str = "test-cred-id"
) -> HostScopedCredentials:
return HostScopedCredentials(
id=cred_id,
provider="http",
host=host,
headers={"Authorization": SecretStr("Bearer test-token")},
title=f"Cred for {host}",
)
def _make_field_info(
self, discriminator_values: set | None = None
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.HTTP]),
credentials_types=_HOST_SCOPED_TYPES,
credentials_scopes=None,
discriminator="url",
discriminator_values=discriminator_values or set(),
)
def test_matches_credential_for_correct_host(self):
"""A host-scoped credential matching the URL host should be returned."""
cred = self._make_host_scoped_cred("api.example.com")
field_info = self._make_field_info({"https://api.example.com/v1/data"})
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_host(self):
"""A host-scoped credential for a different host should not match."""
cred = self._make_host_scoped_cred("api.github.com")
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_any_when_no_discriminator_values(self):
"""With empty discriminator_values, any host-scoped credential matches.
Note: this tests the current fallback behavior in _credential_is_for_host()
where empty discriminator_values means "no host constraint" and any
host-scoped credential is accepted. This is by design for the case where
the target URL is not yet known (e.g. schema preview with empty input).
"""
cred = self._make_host_scoped_cred("api.anything.com")
field_info = self._make_field_info(set())
result = find_matching_credential([cred], field_info)
assert result is not None
def test_wildcard_host_matching(self):
"""Wildcard host (*.example.com) should match subdomains."""
cred = self._make_host_scoped_cred("*.example.com")
field_info = self._make_field_info({"https://api.example.com/v1/data"})
result = find_matching_credential([cred], field_info)
assert result is not None
def test_selects_correct_credential_from_multiple(self):
"""When multiple host-scoped credentials exist, the correct one is selected."""
cred_github = self._make_host_scoped_cred("api.github.com", "github-cred")
cred_stripe = self._make_host_scoped_cred("api.stripe.com", "stripe-cred")
field_info = self._make_field_info({"https://api.stripe.com/v1/charges"})
result = find_matching_credential([cred_github, cred_stripe], field_info)
assert result is not None
assert result.id == "stripe-cred"
# ---------------------------------------------------------------------------
# find_matching_credential tests (api_key)
# ---------------------------------------------------------------------------
class TestFindMatchingAPIKeyCredential:
"""Tests for find_matching_credential with API key credentials."""
def _make_api_key_cred(
self, provider: str = "google_maps", cred_id: str = "test-api-key-id"
) -> APIKeyCredentials:
return APIKeyCredentials(
id=cred_id,
provider=provider,
api_key=SecretStr("sk-test-key-123"),
title=f"API key for {provider}",
expires_at=None,
)
def _make_field_info(
self, provider: ProviderName = ProviderName.GOOGLE_MAPS
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
def test_matches_credential_for_correct_provider(self):
"""An API key credential matching the provider should be returned."""
cred = self._make_api_key_cred("google_maps")
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""An API key credential for a different provider should not match."""
cred = self._make_api_key_cred("openai")
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([cred], field_info)
assert result is None
def test_rejects_credential_for_wrong_type(self):
"""An OAuth2 credential should not match an api_key requirement."""
oauth_cred = OAuth2Credentials(
id="oauth-cred-id",
provider="google_maps",
access_token=SecretStr("mock-token"),
scopes=[],
title="OAuth cred (wrong type)",
)
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([oauth_cred], field_info)
assert result is None
def test_selects_correct_credential_from_multiple(self):
"""When multiple API key credentials exist, the correct provider is selected."""
cred_maps = self._make_api_key_cred("google_maps", "maps-key")
cred_openai = self._make_api_key_cred("openai", "openai-key")
field_info = self._make_field_info(ProviderName.OPENAI)
result = find_matching_credential([cred_maps, cred_openai], field_info)
assert result is not None
assert result.id == "openai-key"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.GOOGLE_MAPS)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (oauth2)
# ---------------------------------------------------------------------------
class TestFindMatchingOAuth2Credential:
"""Tests for find_matching_credential with OAuth2 credentials."""
def _make_oauth2_cred(
self,
provider: str = "google",
scopes: list[str] | None = None,
cred_id: str = "test-oauth2-id",
) -> OAuth2Credentials:
return OAuth2Credentials(
id=cred_id,
provider=provider,
access_token=SecretStr("mock-access-token"),
refresh_token=SecretStr("mock-refresh-token"),
access_token_expires_at=1234567890,
scopes=scopes or [],
title=f"OAuth2 cred for {provider}",
)
def _make_field_info(
self,
provider: ProviderName = ProviderName.GOOGLE,
required_scopes: frozenset[str] | None = None,
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_OAUTH2_TYPES,
credentials_scopes=required_scopes,
)
def test_matches_credential_for_correct_provider(self):
"""An OAuth2 credential matching the provider should be returned."""
cred = self._make_oauth2_cred("google")
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""An OAuth2 credential for a different provider should not match."""
cred = self._make_oauth2_cred("github")
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_credential_with_required_scopes(self):
"""An OAuth2 credential with all required scopes should match."""
cred = self._make_oauth2_cred(
"google",
scopes=[
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
],
)
field_info = self._make_field_info(
ProviderName.GOOGLE,
required_scopes=frozenset(
["https://www.googleapis.com/auth/gmail.readonly"]
),
)
result = find_matching_credential([cred], field_info)
assert result is not None
def test_rejects_credential_with_insufficient_scopes(self):
"""An OAuth2 credential missing required scopes should not match."""
cred = self._make_oauth2_cred(
"google",
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
field_info = self._make_field_info(
ProviderName.GOOGLE,
required_scopes=frozenset(
[
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
]
),
)
result = find_matching_credential([cred], field_info)
assert result is None
def test_matches_credential_when_no_scopes_required(self):
"""An OAuth2 credential should match when no scopes are required."""
cred = self._make_oauth2_cred("google", scopes=[])
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([cred], field_info)
assert result is not None
def test_selects_correct_credential_from_multiple(self):
"""When multiple OAuth2 credentials exist, the correct one is selected."""
cred_google = self._make_oauth2_cred("google", cred_id="google-cred")
cred_github = self._make_oauth2_cred("github", cred_id="github-cred")
field_info = self._make_field_info(ProviderName.GITHUB)
result = find_matching_credential([cred_google, cred_github], field_info)
assert result is not None
assert result.id == "github-cred"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.GOOGLE)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (user_password)
# ---------------------------------------------------------------------------
class TestFindMatchingUserPasswordCredential:
"""Tests for find_matching_credential with user/password credentials."""
def _make_user_password_cred(
self, provider: str = "smtp", cred_id: str = "test-userpass-id"
) -> UserPasswordCredentials:
return UserPasswordCredentials(
id=cred_id,
provider=provider,
username=SecretStr("test-user"),
password=SecretStr("test-pass"),
title=f"Credentials for {provider}",
)
def _make_field_info(
self, provider: ProviderName = ProviderName.SMTP
) -> CredentialsFieldInfo:
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_types=_USER_PASSWORD_TYPES,
credentials_scopes=None,
)
def test_matches_credential_for_correct_provider(self):
"""A user/password credential matching the provider should be returned."""
cred = self._make_user_password_cred("smtp")
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([cred], field_info)
assert result is not None
assert result.id == cred.id
def test_rejects_credential_for_wrong_provider(self):
"""A user/password credential for a different provider should not match."""
cred = self._make_user_password_cred("smtp")
field_info = self._make_field_info(ProviderName.HUBSPOT)
result = find_matching_credential([cred], field_info)
assert result is None
def test_rejects_credential_for_wrong_type(self):
"""An API key credential should not match a user_password requirement."""
api_key_cred = APIKeyCredentials(
id="api-key-cred-id",
provider="smtp",
api_key=SecretStr("wrong-type-key"),
title="API key cred (wrong type)",
)
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([api_key_cred], field_info)
assert result is None
def test_selects_correct_credential_from_multiple(self):
"""When multiple user/password credentials exist, the correct one is selected."""
cred_smtp = self._make_user_password_cred("smtp", "smtp-cred")
cred_hubspot = self._make_user_password_cred("hubspot", "hubspot-cred")
field_info = self._make_field_info(ProviderName.HUBSPOT)
result = find_matching_credential([cred_smtp, cred_hubspot], field_info)
assert result is not None
assert result.id == "hubspot-cred"
def test_returns_none_when_no_credentials(self):
"""Should return None when the credential list is empty."""
field_info = self._make_field_info(ProviderName.SMTP)
result = find_matching_credential([], field_info)
assert result is None
# ---------------------------------------------------------------------------
# find_matching_credential tests (mixed credential types)
# ---------------------------------------------------------------------------
class TestFindMatchingCredentialMixedTypes:
"""Tests that find_matching_credential correctly filters by type in a mixed list."""
def test_selects_api_key_from_mixed_list(self):
"""API key requirement should skip OAuth2 and user_password credentials."""
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="openai",
access_token=SecretStr("token"),
scopes=[],
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="openai",
username=SecretStr("user"),
password=SecretStr("pass"),
)
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="openai",
api_key=SecretStr("sk-key"),
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.OPENAI]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
result = find_matching_credential(
[oauth_cred, userpass_cred, api_key_cred], field_info
)
assert result is not None
assert result.id == "apikey-id"
def test_selects_oauth2_from_mixed_list(self):
"""OAuth2 requirement should skip API key and user_password credentials."""
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="google",
api_key=SecretStr("key"),
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="google",
username=SecretStr("user"),
password=SecretStr("pass"),
)
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="google",
access_token=SecretStr("token"),
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.GOOGLE]),
credentials_types=_OAUTH2_TYPES,
credentials_scopes=frozenset(
["https://www.googleapis.com/auth/gmail.readonly"]
),
)
result = find_matching_credential(
[api_key_cred, userpass_cred, oauth_cred], field_info
)
assert result is not None
assert result.id == "oauth-id"
def test_selects_user_password_from_mixed_list(self):
"""User/password requirement should skip API key and OAuth2 credentials."""
api_key_cred = APIKeyCredentials(
id="apikey-id",
provider="smtp",
api_key=SecretStr("key"),
)
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="smtp",
access_token=SecretStr("token"),
scopes=[],
)
userpass_cred = UserPasswordCredentials(
id="userpass-id",
provider="smtp",
username=SecretStr("user"),
password=SecretStr("pass"),
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.SMTP]),
credentials_types=_USER_PASSWORD_TYPES,
credentials_scopes=None,
)
result = find_matching_credential(
[api_key_cred, oauth_cred, userpass_cred], field_info
)
assert result is not None
assert result.id == "userpass-id"
def test_returns_none_when_only_wrong_types_available(self):
"""Should return None when all available creds have the wrong type."""
oauth_cred = OAuth2Credentials(
id="oauth-id",
provider="google_maps",
access_token=SecretStr("token"),
scopes=[],
)
field_info = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.GOOGLE_MAPS]),
credentials_types=_API_KEY_TYPES,
credentials_scopes=None,
)
result = find_matching_credential([oauth_cred], field_info)
assert result is None
# ---------------------------------------------------------------------------
# resolve_block_credentials tests (integration — all credential types)
# ---------------------------------------------------------------------------
class TestResolveBlockCredentials:
"""Integration tests for resolve_block_credentials across credential types."""
async def test_matches_host_scoped_credential_for_url(self):
"""resolve_block_credentials should match a host-scoped cred for the given URL."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.example.com/v1/data"}
mock_cred = HostScopedCredentials(
id="matching-cred-id",
provider="http",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
title="Example API Cred",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "matching-cred-id"
assert len(missing) == 0
async def test_reports_missing_when_no_matching_host(self):
"""resolve_block_credentials should report missing creds when host doesn't match."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.stripe.com/v1/charges"}
wrong_host_cred = HostScopedCredentials(
id="wrong-cred-id",
provider="http",
host="api.github.com",
headers={"Authorization": SecretStr("Bearer token")},
title="GitHub API Cred",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_host_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_reports_missing_when_no_credentials(self):
"""resolve_block_credentials should report missing when user has no creds at all."""
block = SendAuthenticatedWebRequestBlock()
input_data = {"url": "https://api.example.com/v1/data"}
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_api_key_credential_for_llm_block(self):
"""resolve_block_credentials should match an API key cred for an LLM block."""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
mock_cred = APIKeyCredentials(
id="openai-key-id",
provider="openai",
api_key=SecretStr("sk-test-key"),
title="OpenAI API Key",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "openai-key-id"
assert len(missing) == 0
async def test_reports_missing_api_key_for_wrong_provider(self):
"""resolve_block_credentials should report missing when API key provider mismatches."""
from backend.blocks.llm import AITextGeneratorBlock
block = AITextGeneratorBlock()
input_data = {"model": "gpt-4o-mini"}
wrong_provider_cred = APIKeyCredentials(
id="wrong-key-id",
provider="google_maps",
api_key=SecretStr("sk-wrong"),
title="Google Maps Key",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_provider_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_oauth2_credential_for_google_block(self):
"""resolve_block_credentials should match an OAuth2 cred for a Google block."""
from backend.blocks.google.gmail import GmailReadBlock
block = GmailReadBlock()
input_data = {}
mock_cred = OAuth2Credentials(
id="google-oauth-id",
provider="google",
access_token=SecretStr("mock-token"),
refresh_token=SecretStr("mock-refresh"),
access_token_expires_at=9999999999,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
title="Google OAuth",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "google-oauth-id"
assert len(missing) == 0
async def test_reports_missing_oauth2_with_insufficient_scopes(self):
"""resolve_block_credentials should report missing when OAuth2 scopes are insufficient."""
from backend.blocks.google.gmail import GmailSendBlock
block = GmailSendBlock()
input_data = {}
# GmailSendBlock requires gmail.send scope; provide only readonly
insufficient_cred = OAuth2Credentials(
id="limited-oauth-id",
provider="google",
access_token=SecretStr("mock-token"),
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
title="Google OAuth (limited)",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[insufficient_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
async def test_matches_user_password_credential_for_email_block(self):
"""resolve_block_credentials should match a user/password cred for an SMTP block."""
from backend.blocks.email_block import SendEmailBlock
block = SendEmailBlock()
input_data = {}
mock_cred = UserPasswordCredentials(
id="smtp-cred-id",
provider="smtp",
username=SecretStr("test-user"),
password=SecretStr("test-pass"),
title="SMTP Credentials",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert "credentials" in matched
assert matched["credentials"].id == "smtp-cred-id"
assert len(missing) == 0
async def test_reports_missing_user_password_for_wrong_provider(self):
"""resolve_block_credentials should report missing when user/password provider mismatches."""
from backend.blocks.email_block import SendEmailBlock
block = SendEmailBlock()
input_data = {}
wrong_cred = UserPasswordCredentials(
id="wrong-cred-id",
provider="dataforseo",
username=SecretStr("user"),
password=SecretStr("pass"),
title="DataForSEO Creds",
)
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[wrong_cred],
):
matched, missing = await resolve_block_credentials(
_TEST_USER_ID, block, input_data
)
assert len(matched) == 0
assert len(missing) == 1
# ---------------------------------------------------------------------------
# RunBlockTool integration tests for authenticated HTTP
# ---------------------------------------------------------------------------
class TestRunBlockToolAuthenticatedHttp:
"""End-to-end tests for RunBlockTool with SendAuthenticatedWebRequestBlock."""
async def test_returns_setup_requirements_when_creds_missing(self):
"""When no matching host-scoped credential exists, return SetupRequirementsResponse."""
session = make_session(user_id=_TEST_USER_ID)
block = SendAuthenticatedWebRequestBlock()
with patch(
"backend.copilot.tools.helpers.get_block",
return_value=block,
):
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[],
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=block.id,
input_data={"url": "https://api.example.com/data", "method": "GET"},
)
assert isinstance(response, SetupRequirementsResponse)
assert "credentials" in response.message.lower()
async def test_returns_details_when_creds_matched_but_missing_required_inputs(self):
"""When creds present + required inputs missing -> BlockDetailsResponse.
Note: with input_data={}, no URL is provided so discriminator_values is
empty, meaning _credential_is_for_host() matches any host-scoped
credential vacuously. This test exercises the "creds present + inputs
missing" branch, not host-based matching (which is covered by
TestFindMatchingHostScopedCredential and TestResolveBlockCredentials).
"""
session = make_session(user_id=_TEST_USER_ID)
block = SendAuthenticatedWebRequestBlock()
mock_cred = HostScopedCredentials(
id="matching-cred-id",
provider="http",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
title="Example API Cred",
)
with patch(
"backend.copilot.tools.helpers.get_block",
return_value=block,
):
with patch(
"backend.copilot.tools.utils.get_user_credentials",
new_callable=AsyncMock,
return_value=[mock_cred],
):
tool = RunBlockTool()
# Call with empty input to get schema
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=block.id,
input_data={},
)
assert isinstance(response, BlockDetailsResponse)
assert response.block.name == block.name
# The matched credential should be included in the details
assert len(response.block.credentials) > 0
assert response.block.credentials[0].id == "matching-cred-id"

View File

@@ -272,6 +272,7 @@ class ExecutionOutputInfo(BaseModel):
ended_at: datetime | None = None
outputs: dict[str, list[Any]]
inputs_summary: dict[str, Any] | None = None
node_executions: list[dict[str, Any]] | None = None
class AgentOutputResponse(ToolResponseBase):
@@ -457,6 +458,7 @@ class BlockOutputResponse(ToolResponseBase):
block_name: str
outputs: dict[str, list[Any]]
success: bool = True
is_dry_run: bool = False
class ReviewRequiredResponse(ToolResponseBase):

View File

@@ -71,6 +71,7 @@ class RunAgentInput(BaseModel):
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
dry_run: bool = False
@field_validator(
"username_agent_slug",
@@ -150,6 +151,14 @@ class RunAgentTool(BaseTool):
"minimum": 0,
"maximum": 300,
},
"dry_run": {
"type": "boolean",
"description": (
"When true, simulates the entire agent execution using an LLM "
"for each block — no real API calls, no credentials needed, "
"no credits charged. Useful for testing agent wiring end-to-end."
),
},
},
"required": [],
}
@@ -229,103 +238,17 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
# Step 2: Check credentials
graph_credentials, missing_creds = await match_user_credentials_to_graph(
user_id, graph
# Step 2: Check credentials and inputs
graph_credentials, prereq_error = await self._check_prerequisites(
graph=graph,
user_id=user_id,
params=params,
session_id=session_id,
)
if prereq_error:
return prereq_error
if missing_creds:
# Return credentials needed response with input data info
# The UI handles credential setup automatically, so the message
# focuses on asking about input data
requirements_creds_dict = build_missing_credentials_from_graph(
graph, None
)
missing_credentials_dict = build_missing_credentials_from_graph(
graph, graph_credentials
)
requirements_creds_list = list(requirements_creds_dict.values())
return SetupRequirementsResponse(
message=self._build_inputs_message(graph, MSG_WHAT_VALUES_TO_USE),
session_id=session_id,
setup_info=SetupInfo(
agent_id=graph.id,
agent_name=graph.name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_credentials_dict,
ready_to_run=False,
),
requirements={
"credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
graph_id=graph.id,
graph_version=graph.version,
)
# Step 3: Check inputs
# Get all available input fields from schema
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide
if input_properties and not provided_inputs and not params.use_defaults:
credentials = extract_credentials_from_schema(
graph.credentials_input_schema
)
return AgentDetailsResponse(
message=self._build_inputs_message(graph, MSG_ASK_USER_FOR_VALUES),
session_id=session_id,
agent=self._build_agent_details(graph, credentials),
user_authenticated=True,
graph_id=graph.id,
graph_version=graph.version,
)
# Check if required inputs are missing (and not using defaults)
missing_inputs = required_fields - provided_inputs
if missing_inputs and not params.use_defaults:
# Return agent details with missing inputs info
credentials = extract_credentials_from_schema(
graph.credentials_input_schema
)
return AgentDetailsResponse(
message=(
f"Agent '{graph.name}' is missing required inputs: "
f"{', '.join(missing_inputs)}. "
"Please provide these values to run the agent."
),
session_id=session_id,
agent=self._build_agent_details(graph, credentials),
user_authenticated=True,
graph_id=graph.id,
graph_version=graph.version,
)
# Step 4: Execute or Schedule
# Step 3: Execute or Schedule
if is_schedule:
return await self._schedule_agent(
user_id=user_id,
@@ -345,6 +268,7 @@ class RunAgentTool(BaseTool):
graph_credentials=graph_credentials,
inputs=params.inputs,
wait_for_result=params.wait_for_result,
dry_run=params.dry_run,
)
except NotFoundError as e:
@@ -354,14 +278,14 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
except DatabaseError as e:
logger.error(f"Database error: {e}", exc_info=True)
logger.error("Database error: %s", e, exc_info=True)
return ErrorResponse(
message=f"Failed to process request: {e!s}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error processing agent request: {e}", exc_info=True)
logger.error("Error processing agent request: %s", e, exc_info=True)
return ErrorResponse(
message=f"Failed to process request: {e!s}",
error=str(e),
@@ -421,6 +345,112 @@ class RunAgentTool(BaseTool):
trigger_info=trigger_info,
)
async def _check_prerequisites(
self,
graph: GraphModel,
user_id: str,
params: "RunAgentInput",
session_id: str,
) -> tuple[dict[str, CredentialsMetaInput], ToolResponseBase | None]:
"""Validate credentials and inputs before execution.
Dry runs skip all prerequisite gates (credentials, input prompts)
since simulate_block doesn't need real credentials or complete inputs.
Returns:
(graph_credentials, error_response) — error_response is None when ready.
"""
graph_credentials, missing_creds = await match_user_credentials_to_graph(
user_id, graph
)
# --- Reject unknown input fields (always, even for dry runs) ---
input_properties = graph.input_schema.get("properties", {})
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return graph_credentials, InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# Dry runs bypass remaining prerequisite gates (credentials, missing inputs)
if params.dry_run:
return graph_credentials, None
# --- Credential gate ---
if missing_creds:
requirements_creds_dict = build_missing_credentials_from_graph(graph, None)
missing_credentials_dict = build_missing_credentials_from_graph(
graph, graph_credentials
)
return graph_credentials, SetupRequirementsResponse(
message=self._build_inputs_message(graph, MSG_WHAT_VALUES_TO_USE),
session_id=session_id,
setup_info=SetupInfo(
agent_id=graph.id,
agent_name=graph.name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_credentials_dict,
ready_to_run=False,
),
requirements={
"credentials": list(requirements_creds_dict.values()),
"inputs": get_inputs_from_schema(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
graph_id=graph.id,
graph_version=graph.version,
)
# --- Input gates ---
required_fields = set(graph.input_schema.get("required", []))
# Prompt user when inputs exist but none were provided
if input_properties and not provided_inputs and not params.use_defaults:
credentials = extract_credentials_from_schema(
graph.credentials_input_schema
)
return graph_credentials, AgentDetailsResponse(
message=self._build_inputs_message(graph, MSG_ASK_USER_FOR_VALUES),
session_id=session_id,
agent=self._build_agent_details(graph, credentials),
user_authenticated=True,
graph_id=graph.id,
graph_version=graph.version,
)
# Required inputs missing
missing_inputs = required_fields - provided_inputs
if missing_inputs and not params.use_defaults:
credentials = extract_credentials_from_schema(
graph.credentials_input_schema
)
return graph_credentials, AgentDetailsResponse(
message=(
f"Agent '{graph.name}' is missing required inputs: "
f"{', '.join(missing_inputs)}. "
"Please provide these values to run the agent."
),
session_id=session_id,
agent=self._build_agent_details(graph, credentials),
user_authenticated=True,
graph_id=graph.id,
graph_version=graph.version,
)
return graph_credentials, None
async def _run_agent(
self,
user_id: str,
@@ -429,12 +459,16 @@ class RunAgentTool(BaseTool):
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
wait_for_result: int = 0,
dry_run: bool = False,
) -> ToolResponseBase:
"""Execute an agent immediately, optionally waiting for completion."""
session_id = session.session_id
# Check rate limits
if session.successful_agent_runs.get(graph.id, 0) >= config.max_agent_runs:
# Check rate limits (dry runs don't count against the session limit)
if (
not dry_run
and session.successful_agent_runs.get(graph.id, 0) >= config.max_agent_runs
):
return ErrorResponse(
message="Maximum agent runs reached for this session. Please try again later.",
session_id=session_id,
@@ -449,12 +483,14 @@ class RunAgentTool(BaseTool):
user_id=user_id,
inputs=inputs,
graph_credentials_inputs=graph_credentials,
dry_run=dry_run,
)
# Track successful run
session.successful_agent_runs[library_agent.graph_id] = (
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
)
# Track successful run (dry runs don't count against the session limit)
if not dry_run:
session.successful_agent_runs[library_agent.graph_id] = (
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
)
# Track in PostHog
track_agent_run_success(

View File

@@ -1,8 +1,10 @@
"""Tool for executing blocks directly."""
import logging
import uuid
from typing import Any
from backend.copilot.constants import COPILOT_NODE_EXEC_ID_SEPARATOR
from backend.copilot.context import get_current_permissions
from backend.copilot.model import ChatSession
@@ -47,6 +49,14 @@ class RunBlockTool(BaseTool):
"type": "object",
"description": "Input values. Use {} first to see schema.",
},
"dry_run": {
"type": "boolean",
"description": (
"When true, simulates block execution using an LLM without making any "
"real API calls or producing side effects. Useful for testing agent "
"wiring and previewing outputs. Default: false."
),
},
},
"required": ["block_id", "input_data"],
}
@@ -76,6 +86,7 @@ class RunBlockTool(BaseTool):
"""
block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
dry_run = bool(kwargs.get("dry_run", False))
session_id = session.session_id
if not block_id:
@@ -104,6 +115,7 @@ class RunBlockTool(BaseTool):
user_id=user_id,
session=session,
session_id=session_id,
dry_run=dry_run,
)
if isinstance(prep_or_err, ToolResponseBase):
return prep_or_err
@@ -130,6 +142,27 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Dry-run fast-path: skip credential/HITL checks — simulation never calls
# the real service so credentials and review gates are not needed.
# Input field validation (unrecognized fields) is already handled by
# prepare_block_for_execution above.
if dry_run:
synthetic_node_exec_id = (
f"{prep.synthetic_node_id}"
f"{COPILOT_NODE_EXEC_ID_SEPARATOR}"
f"{uuid.uuid4().hex[:8]}"
)
return await execute_block(
block=prep.block,
block_id=block_id,
input_data=prep.input_data,
user_id=user_id,
session_id=session_id,
node_exec_id=synthetic_node_exec_id,
matched_credentials=prep.matched_credentials,
dry_run=True,
)
# Show block details when required inputs are not yet provided.
# This is run_block's two-step UX: first call returns the schema,
# second call (with inputs) actually executes.
@@ -177,4 +210,5 @@ class RunBlockTool(BaseTool):
session_id=session_id,
node_exec_id=synthetic_node_exec_id,
matched_credentials=prep.matched_credentials,
dry_run=dry_run,
)

View File

@@ -0,0 +1,358 @@
"""Tests for dry-run execution mode."""
import inspect
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import backend.copilot.tools.run_block as run_block_module
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
from backend.copilot.tools.run_block import RunBlockTool
from backend.executor.simulator import build_simulation_prompt, simulate_block
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_mock_block(
name: str = "TestBlock",
description: str = "A test block",
input_props: dict | None = None,
output_props: dict | None = None,
):
"""Create a minimal mock block with jsonschema() methods."""
block = MagicMock()
block.name = name
block.description = description
in_props = input_props or {"query": {"type": "string"}}
out_props = output_props or {
"result": {"type": "string"},
"error": {"type": "string"},
}
block.input_schema = MagicMock()
block.input_schema.jsonschema.return_value = {
"type": "object",
"properties": in_props,
"required": list(in_props.keys()),
}
block.input_schema.get_credentials_fields.return_value = {}
block.input_schema.get_credentials_fields_info.return_value = {}
block.output_schema = MagicMock()
block.output_schema.jsonschema.return_value = {
"type": "object",
"properties": out_props,
"required": ["result"],
}
return block
def make_openai_response(
content: str, prompt_tokens: int = 100, completion_tokens: int = 50
):
"""Build a mock OpenAI chat completion response."""
response = MagicMock()
response.choices = [MagicMock()]
response.choices[0].message.content = content
response.usage = MagicMock()
response.usage.prompt_tokens = prompt_tokens
response.usage.completion_tokens = completion_tokens
return response
# ---------------------------------------------------------------------------
# simulate_block tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_simulate_block_basic():
"""simulate_block returns correct (output_name, output_data) tuples."""
mock_block = make_mock_block()
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(
return_value=make_openai_response('{"result": "simulated output", "error": ""}')
)
with patch(
"backend.executor.simulator.get_openai_client", return_value=mock_client
):
outputs = []
async for name, data in simulate_block(mock_block, {"query": "test"}):
outputs.append((name, data))
assert ("result", "simulated output") in outputs
assert ("error", "") in outputs
@pytest.mark.asyncio
async def test_simulate_block_json_retry():
"""LLM returns invalid JSON twice then valid; verifies 3 total calls."""
mock_block = make_mock_block()
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(
side_effect=[
make_openai_response("not json at all"),
make_openai_response("still not json"),
make_openai_response('{"result": "ok", "error": ""}'),
]
)
with patch(
"backend.executor.simulator.get_openai_client", return_value=mock_client
):
outputs = []
async for name, data in simulate_block(mock_block, {"query": "test"}):
outputs.append((name, data))
assert mock_client.chat.completions.create.call_count == 3
assert ("result", "ok") in outputs
@pytest.mark.asyncio
async def test_simulate_block_all_retries_exhausted():
"""LLM always returns invalid JSON; verify error tuple is yielded."""
mock_block = make_mock_block()
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock(
return_value=make_openai_response("bad json !!!")
)
with patch(
"backend.executor.simulator.get_openai_client", return_value=mock_client
):
outputs = []
async for name, data in simulate_block(mock_block, {"query": "test"}):
outputs.append((name, data))
# All retry attempts should have been consumed
assert mock_client.chat.completions.create.call_count == 5 # _MAX_JSON_RETRIES
assert len(outputs) == 1
name, data = outputs[0]
assert name == "error"
assert "[SIMULATOR ERROR" in data
@pytest.mark.asyncio
async def test_simulate_block_missing_output_pins():
"""LLM response missing some output pins; verify they're filled with None."""
mock_block = make_mock_block(
output_props={
"result": {"type": "string"},
"count": {"type": "integer"},
"error": {"type": "string"},
}
)
mock_client = AsyncMock()
# Only returns "result", missing "count" and "error"
mock_client.chat.completions.create = AsyncMock(
return_value=make_openai_response('{"result": "hello"}')
)
with patch(
"backend.executor.simulator.get_openai_client", return_value=mock_client
):
outputs = {}
async for name, data in simulate_block(mock_block, {"query": "hi"}):
outputs[name] = data
assert outputs["result"] == "hello"
assert outputs["count"] is None # missing pin filled with None
assert outputs["error"] == "" # "error" pin filled with ""
@pytest.mark.asyncio
async def test_simulate_block_no_client():
"""When no OpenAI client is available, yields SIMULATOR ERROR."""
mock_block = make_mock_block()
with patch("backend.executor.simulator.get_openai_client", return_value=None):
outputs = []
async for name, data in simulate_block(mock_block, {}):
outputs.append((name, data))
assert len(outputs) == 1
name, data = outputs[0]
assert name == "error"
assert "[SIMULATOR ERROR" in data
@pytest.mark.asyncio
async def test_simulate_block_truncates_long_inputs():
"""Inputs with very long strings should be truncated in the prompt."""
mock_block = make_mock_block(input_props={"text": {"type": "string"}})
long_text = "x" * 30000 # 30k chars, above the 20k threshold
system_prompt, user_prompt = build_simulation_prompt(
mock_block, {"text": long_text}
)
# The user prompt should contain TRUNCATED marker
assert "[TRUNCATED]" in user_prompt
# And the total length of the value in the prompt should be well under 30k chars
parsed = json.loads(user_prompt.split("## Current Inputs\n", 1)[1])
assert len(parsed["text"]) < 25000
# ---------------------------------------------------------------------------
# execute_block dry-run tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_execute_block_dry_run_skips_real_execution():
"""execute_block(dry_run=True) calls simulate_block, NOT block.execute."""
mock_block = make_mock_block()
mock_block.execute = AsyncMock() # should NOT be called
async def fake_simulate(block, input_data):
yield "result", "simulated"
# Patching at helpers.simulate_block works because helpers.py imports
# simulate_block at the top of the module. If the import were lazy
# (inside the function), we'd need to patch the source module instead.
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
mock_block.execute.assert_not_called()
assert isinstance(response, BlockOutputResponse)
assert response.success is True
@pytest.mark.asyncio
async def test_execute_block_dry_run_response_format():
"""Dry-run response should contain [DRY RUN] in message and success=True."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
yield "result", "simulated"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, BlockOutputResponse)
assert "[DRY RUN]" in response.message
assert response.success is True
assert response.outputs == {"result": ["simulated"]}
@pytest.mark.asyncio
async def test_execute_block_real_execution_unchanged():
"""dry_run=False should still go through the real execution path."""
mock_block = make_mock_block()
# We expect it to hit the real path, which will fail on workspace_db() call.
# Just verify simulate_block is NOT called.
simulate_called = False
async def fake_simulate(block, input_data):
nonlocal simulate_called
simulate_called = True
yield "result", "should not happen"
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate
):
with patch(
"backend.copilot.tools.helpers.workspace_db",
side_effect=Exception("db not available"),
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=False,
)
assert simulate_called is False
# The real path raised an exception, so we get an ErrorResponse (which has .error attr)
assert hasattr(response, "error")
# ---------------------------------------------------------------------------
# RunBlockTool parameter tests
# ---------------------------------------------------------------------------
def test_run_block_tool_dry_run_param():
"""RunBlockTool parameters should include 'dry_run'."""
tool = RunBlockTool()
params = tool.parameters
assert "dry_run" in params["properties"]
assert params["properties"]["dry_run"]["type"] == "boolean"
def test_run_block_tool_dry_run_calls_execute():
"""RunBlockTool._execute extracts dry_run from kwargs correctly.
We verify the extraction logic directly by inspecting the source, then confirm
the kwarg is forwarded in the execute_block call site.
"""
source = inspect.getsource(run_block_module.RunBlockTool._execute)
# Verify dry_run is extracted from kwargs
assert "dry_run" in source
assert 'kwargs.get("dry_run"' in source
# Scope to _execute method source only — module-wide search is brittle
# and can match unrelated text/comments.
source_execute = inspect.getsource(run_block_module.RunBlockTool._execute)
# Verify dry_run is passed through to execute_block call
assert "dry_run=dry_run" in source_execute
@pytest.mark.asyncio
async def test_execute_block_dry_run_simulator_error_returns_error_response():
"""When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
mock_block = make_mock_block()
async def fake_simulate_error(block, input_data):
yield "error", "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key)."
with patch(
"backend.copilot.tools.helpers.simulate_block", side_effect=fake_simulate_error
):
response = await execute_block(
block=mock_block,
block_id="test-block-id",
input_data={"query": "hello"},
user_id="user-1",
session_id="session-1",
node_exec_id="node-exec-1",
matched_credentials={},
dry_run=True,
)
assert isinstance(response, ErrorResponse)
assert "[SIMULATOR ERROR" in response.message

View File

@@ -121,7 +121,7 @@ def _serialize_missing_credential(
provider = next(iter(field_info.provider), "unknown")
scopes = sorted(field_info.required_scopes or [])
return {
result: dict[str, Any] = {
"id": field_key,
"title": field_key.replace("_", " ").title(),
"provider": provider,
@@ -131,6 +131,17 @@ def _serialize_missing_credential(
"scopes": scopes,
}
# Include discriminator info so the frontend can auto-match
# host-scoped credentials (e.g. SendAuthenticatedWebRequestBlock).
if field_info.discriminator:
result["discriminator"] = field_info.discriminator
if field_info.discriminator_values:
result["discriminator_values"] = sorted(
str(v) for v in field_info.discriminator_values
)
return result
def build_missing_credentials_from_graph(
graph: GraphModel, matched_credentials: dict[str, CredentialsMetaInput] | None

View File

@@ -344,6 +344,7 @@ class DatabaseManager(AppService):
get_next_sequence = _(chat_db.get_next_sequence)
update_tool_message_content = _(chat_db.update_tool_message_content)
update_chat_session_title = _(chat_db.update_chat_session_title)
set_turn_duration = _(chat_db.set_turn_duration)
class DatabaseManagerClient(AppServiceClient):
@@ -540,3 +541,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_next_sequence = d.get_next_sequence
update_tool_message_content = d.update_tool_message_content
update_chat_session_title = d.update_chat_session_title
set_turn_duration = d.set_turn_duration

View File

@@ -89,6 +89,7 @@ class ExecutionContext(BaseModel):
# Safety settings
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
dry_run: bool = False # When True, blocks are LLM-simulated, no real execution
# User settings
user_timezone: str = "UTC"
@@ -178,6 +179,7 @@ class GraphExecutionMeta(BaseDbModel):
)
is_shared: bool = False
share_token: Optional[str] = None
is_dry_run: bool = False
class Stats(BaseModel):
model_config = ConfigDict(
@@ -306,6 +308,7 @@ class GraphExecutionMeta(BaseDbModel):
),
is_shared=_graph_exec.isShared,
share_token=_graph_exec.shareToken,
is_dry_run=stats.is_dry_run if stats else False,
)
@@ -339,6 +342,7 @@ class GraphExecution(GraphExecutionMeta):
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
and "name" in exec.input_data
)
}
),
@@ -357,8 +361,10 @@ class GraphExecution(GraphExecutionMeta):
outputs: CompletedBlockOutput = defaultdict(list)
for exec in complete_node_executions:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
(block := get_block(exec.block_id))
and block.block_type == BlockType.OUTPUT
and "name" in exec.input_data
):
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
return GraphExecution(
@@ -718,11 +724,12 @@ async def create_graph_execution(
graph_version: int,
starting_nodes_input: list[tuple[str, BlockInput]], # list[(node_id, BlockInput)]
inputs: Mapping[str, JsonValue],
user_id: str,
user_id: str, # Validated by callers (API auth layer / service-level checks)
preset_id: Optional[str] = None,
credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
parent_graph_exec_id: Optional[str] = None,
is_dry_run: bool = False,
) -> GraphExecutionWithNodes:
"""
Create a new AgentGraphExecution record.
@@ -760,6 +767,7 @@ async def create_graph_execution(
"userId": user_id,
"agentPresetId": preset_id,
"parentGraphExecutionId": parent_graph_exec_id,
**({"stats": Json({"is_dry_run": True})} if is_dry_run else {}),
},
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)

View File

@@ -581,7 +581,6 @@ class GraphModel(Graph, GraphMeta):
field_name,
field_info,
) in node.block.input_schema.get_credentials_fields_info().items():
discriminator = field_info.discriminator
if not discriminator:
node_credential_data.append((field_info, (node.id, field_name)))
@@ -1096,6 +1095,9 @@ async def get_graph(
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed.
See also: `get_graph_as_admin()` which bypasses ownership and marketplace
checks for admin-only routes.
Returns `None` if the record is not found.
"""
graph = None
@@ -1133,6 +1135,27 @@ async def get_graph(
):
graph = store_listing.AgentGraph
# Fall back to library membership: if the user has the agent in their
# library (non-deleted, non-archived), grant access even if the agent is
# no longer published. "You added it, you keep it."
if graph is None and user_id is not None:
library_where: dict[str, object] = {
"userId": user_id,
"agentGraphId": graph_id,
"isDeleted": False,
"isArchived": False,
}
if version is not None:
library_where["agentGraphVersion"] = version
library_agent = await LibraryAgent.prisma().find_first(
where=library_where,
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
order={"agentGraphVersion": "desc"},
)
if library_agent and library_agent.AgentGraph:
graph = library_agent.AgentGraph
if graph is None:
return None
@@ -1368,8 +1391,9 @@ async def validate_graph_execution_permissions(
## Logic
A user can execute a graph if any of these is true:
1. They own the graph and some version of it is still listed in their library
2. The graph is published in the marketplace and listed in their library
3. The graph is published in the marketplace and is being executed as a sub-agent
2. The graph is in the user's library (non-deleted, non-archived)
3. The graph is published in the marketplace and listed in their library
4. The graph is published in the marketplace and is being executed as a sub-agent
Args:
graph_id: The ID of the graph to check
@@ -1391,6 +1415,7 @@ async def validate_graph_execution_permissions(
where={
"userId": user_id,
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"isDeleted": False,
"isArchived": False,
}
@@ -1400,19 +1425,39 @@ async def validate_graph_execution_permissions(
# Step 1: Check if user owns this graph
user_owns_graph = graph and graph.userId == user_id
# Step 2: Check if agent is in the library *and not deleted*
# Step 2: Check if the exact graph version is in the library.
user_has_in_library = library_agent is not None
owner_has_live_library_entry = user_has_in_library
if user_owns_graph and not user_has_in_library:
# Owners are allowed to execute a new version as long as some live
# library entry still exists for the graph. Non-owners stay
# version-specific.
owner_has_live_library_entry = (
await LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentGraphId": graph_id,
"isDeleted": False,
"isArchived": False,
}
)
is not None
)
# Step 3: Apply permission logic
# Access is granted if the user owns it, it's in the marketplace, OR
# it's in the user's library ("you added it, you keep it").
if not (
user_owns_graph
or user_has_in_library
or await is_graph_published_in_marketplace(graph_id, graph_version)
):
raise GraphNotAccessibleError(
f"You do not have access to graph #{graph_id} v{graph_version}: "
"it is not owned by you and not available in the Marketplace"
"it is not owned by you, not in your library, "
"and not available in the Marketplace"
)
elif not (user_has_in_library or is_sub_graph):
elif not (user_has_in_library or owner_has_live_library_entry or is_sub_graph):
raise GraphNotInLibraryError(f"Graph #{graph_id} is not in your library")
# Step 6: Check execution-specific permissions (raises generic NotAuthorizedError)
@@ -1483,6 +1528,28 @@ async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphMo
async def __create_graph(tx, graph: Graph, user_id: str):
graphs = [graph] + graph.sub_graphs
# Auto-increment version for any graph entry (parent or sub-graph) whose
# (id, version) already exists. This prevents UniqueViolationError when
# the copilot re-saves an agent that already exists at the requested version.
# NOTE: This issues one find_first query per graph entry (N+1 pattern).
# Sub-graph counts are typically small (< 5), so the overhead is negligible.
for g in graphs:
existing = await AgentGraph.prisma(tx).find_first(
where={"id": g.id},
order={"version": "desc"},
)
if existing and existing.version >= g.version:
old_version = g.version
g.version = existing.version + 1
logger.warning(
"Auto-incremented graph %s version from %d to %d "
"(version %d already exists)",
g.id,
old_version,
g.version,
existing.version,
)
await AgentGraph.prisma(tx).create_many(
data=[
AgentGraphCreateInput(

View File

@@ -13,10 +13,17 @@ from backend.api.model import CreateGraph
from backend.blocks._base import BlockSchema, BlockSchemaInput
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.data.graph import Graph, Link, Node, get_graph
from backend.data.graph import (
Graph,
Link,
Node,
get_graph,
validate_graph_execution_permissions,
)
from backend.data.model import SchemaField
from backend.data.user import DEFAULT_USER_ID
from backend.usecases.sample import create_test_user
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
from backend.util.test import SpinTestServer
@@ -597,9 +604,32 @@ def test_mcp_credential_combine_no_discriminator_values():
)
# --------------- get_graph access-control regression tests --------------- #
# These protect the behavior introduced in PR #11323 (Reinier, 2025-11-05):
# non-owners can access APPROVED marketplace agents but NOT pending ones.
# --------------- get_graph access-control truth table --------------- #
#
# Full matrix of access scenarios for get_graph() and get_graph_as_admin().
# Access priority: ownership > marketplace APPROVED > library membership.
# Library is version-specific. get_graph_as_admin bypasses everything.
#
# | User | Owns? | Marketplace | Library | Version | Result | Test
# |----------|-------|-------------|------------------|---------|---------|-----
# | regular | yes | any | any | v1 | ACCESS | test_get_graph_library_not_queried_when_owned
# | regular | no | APPROVED | any | v1 | ACCESS | test_get_graph_non_owner_approved_marketplace_agent
# | regular | no | not listed | active, same ver | v1 | ACCESS | test_get_graph_library_member_can_access_unpublished
# | regular | no | not listed | active, diff ver | v2 | DENIED | test_get_graph_library_wrong_version_denied
# | regular | no | not listed | deleted | v1 | DENIED | test_get_graph_deleted_library_agent_denied
# | regular | no | not listed | archived | v1 | DENIED | test_get_graph_archived_library_agent_denied
# | regular | no | not listed | not present | v1 | DENIED | test_get_graph_non_owner_pending_not_in_library_denied
# | regular | no | PENDING | active v1 | v2 | DENIED | test_library_v1_does_not_grant_access_to_pending_v2
# | regular | no | not listed | null AgentGraph | v1 | DENIED | test_get_graph_library_with_null_agent_graph_denied
# | anon | no | not listed | - | v1 | DENIED | test_get_graph_library_fallback_not_used_for_anonymous
# | anon | no | APPROVED | - | v1 | ACCESS | test_get_graph_anonymous_approved_marketplace_access
# | admin* | no | PENDING | - | v2 | ACCESS | test_admin_can_access_pending_v2_via_get_graph_as_admin
#
# Efficiency (no unnecessary queries):
# | regular | yes | - | - | v1 | no mkt/lib | test_get_graph_library_not_queried_when_owned
# | regular | no | APPROVED | - | v1 | no lib | test_get_graph_library_not_queried_when_marketplace_approved
#
# * = via get_graph_as_admin (admin-only routes)
def _make_mock_db_graph(user_id: str = "owner-user-id") -> MagicMock:
@@ -649,10 +679,9 @@ async def test_get_graph_non_owner_approved_marketplace_agent() -> None:
@pytest.mark.asyncio
async def test_get_graph_non_owner_pending_marketplace_agent_denied() -> None:
"""A non-owner must NOT be able to access a graph that only has a PENDING
(not APPROVED) marketplace listing. The marketplace fallback filters on
submissionStatus=APPROVED, so pending agents should be invisible."""
async def test_get_graph_non_owner_pending_not_in_library_denied() -> None:
"""A non-owner with no library membership and no APPROVED marketplace
listing must be denied access."""
requester_id = "different-user-id"
graph_id = "graph-id"
@@ -661,11 +690,11 @@ async def test_get_graph_non_owner_pending_marketplace_agent_denied() -> None:
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
):
# First lookup (owned graph) returns None
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
# Marketplace fallback finds nothing (not APPROVED)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
result = await get_graph(
graph_id=graph_id,
@@ -673,4 +702,761 @@ async def test_get_graph_non_owner_pending_marketplace_agent_denied() -> None:
user_id=requester_id,
)
assert result is None, "Non-owner must not access a pending marketplace agent"
assert (
result is None
), "User without ownership, marketplace, or library access must be denied"
# --------------- Library membership grants graph access --------------- #
# "You added it, you keep it" — product decision from SECRT-2167.
@pytest.mark.asyncio
async def test_get_graph_library_member_can_access_unpublished() -> None:
"""A user who has the agent in their library should be able to access it
even if it's no longer published in the marketplace."""
requester_id = "library-user-id"
graph_id = "graph-id"
mock_graph = _make_mock_db_graph("original-creator-id")
mock_graph_model = MagicMock(name="GraphModel")
mock_library_agent = MagicMock()
mock_library_agent.AgentGraph = mock_graph
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
# Not owned
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
# Not in marketplace (unpublished)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
# But IS in user's library
mock_lib_prisma.return_value.find_first = AsyncMock(
return_value=mock_library_agent
)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=requester_id,
)
assert result is mock_graph_model, "Library member should access unpublished agent"
# Verify library query filters on non-deleted, non-archived
lib_call = mock_lib_prisma.return_value.find_first
lib_call.assert_awaited_once()
assert lib_call.await_args is not None
lib_where = lib_call.await_args.kwargs["where"]
assert lib_where["userId"] == requester_id
assert lib_where["agentGraphId"] == graph_id
assert lib_where["isDeleted"] is False
assert lib_where["isArchived"] is False
@pytest.mark.asyncio
async def test_get_graph_deleted_library_agent_denied() -> None:
"""If the user soft-deleted the agent from their library, they should
NOT get access via the library fallback."""
requester_id = "library-user-id"
graph_id = "graph-id"
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
# Library query returns None because isDeleted=False filter excludes it
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=requester_id,
)
assert result is None, "Deleted library agent should not grant graph access"
@pytest.mark.asyncio
async def test_get_graph_anonymous_approved_marketplace_access() -> None:
"""Anonymous users (user_id=None) should still access APPROVED marketplace
agents — the marketplace fallback doesn't require authentication."""
graph_id = "graph-id"
mock_graph = _make_mock_db_graph("creator-id")
mock_graph_model = MagicMock(name="GraphModel")
mock_listing = MagicMock()
mock_listing.AgentGraph = mock_graph
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=mock_listing)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=None,
)
assert (
result is mock_graph_model
), "Anonymous user should access APPROVED marketplace agent"
@pytest.mark.asyncio
async def test_get_graph_library_fallback_not_used_for_anonymous() -> None:
"""Anonymous requests (user_id=None) must not trigger the library
fallback — there's no user to check library membership for."""
graph_id = "graph-id"
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=None,
)
assert result is None
# Library should never be queried for anonymous users
mock_lib_prisma.return_value.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_get_graph_library_not_queried_when_owned() -> None:
"""If the user owns the graph, the library fallback should NOT be
triggered — ownership is sufficient."""
owner_id = "owner-user-id"
graph_id = "graph-id"
mock_graph = _make_mock_db_graph(owner_id)
mock_graph_model = MagicMock(name="GraphModel")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
# User owns the graph — first lookup succeeds
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=owner_id,
)
assert result is mock_graph_model
# Neither marketplace nor library should be queried
mock_slv_prisma.return_value.find_first.assert_not_called()
mock_lib_prisma.return_value.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_get_graph_library_not_queried_when_marketplace_approved() -> None:
"""If the graph is APPROVED in the marketplace, the library fallback
should NOT be triggered — marketplace access is sufficient."""
requester_id = "different-user-id"
graph_id = "graph-id"
mock_graph = _make_mock_db_graph("original-creator-id")
mock_graph_model = MagicMock(name="GraphModel")
mock_listing = MagicMock()
mock_listing.AgentGraph = mock_graph
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=mock_listing)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=requester_id,
)
assert result is mock_graph_model
# Library should not be queried — marketplace was sufficient
mock_lib_prisma.return_value.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_get_graph_archived_library_agent_denied() -> None:
"""If the user archived the agent in their library, they should
NOT get access via the library fallback."""
requester_id = "library-user-id"
graph_id = "graph-id"
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
# Library query returns None because isArchived=False filter excludes it
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=requester_id,
)
assert result is None, "Archived library agent should not grant graph access"
@pytest.mark.asyncio
async def test_get_graph_library_with_null_agent_graph_denied() -> None:
"""If LibraryAgent exists but its AgentGraph relation is None
(data integrity issue), access must be denied, not crash."""
requester_id = "library-user-id"
graph_id = "graph-id"
mock_library_agent = MagicMock()
mock_library_agent.AgentGraph = None # broken relation
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_lib_prisma.return_value.find_first = AsyncMock(
return_value=mock_library_agent
)
result = await get_graph(
graph_id=graph_id,
version=1,
user_id=requester_id,
)
assert (
result is None
), "Library agent with missing graph relation should not grant access"
@pytest.mark.asyncio
async def test_get_graph_library_wrong_version_denied() -> None:
"""Having version 1 in your library must NOT grant access to version 2."""
requester_id = "library-user-id"
graph_id = "graph-id"
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
):
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
# Library has version 1 but we're requesting version 2 —
# the where clause includes agentGraphVersion so this returns None
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
result = await get_graph(
graph_id=graph_id,
version=2,
user_id=requester_id,
)
assert (
result is None
), "Library agent for version 1 must not grant access to version 2"
# Verify version was included in the library query
lib_call = mock_lib_prisma.return_value.find_first
lib_call.assert_called_once()
lib_where = lib_call.call_args.kwargs["where"]
assert lib_where["agentGraphVersion"] == 2
@pytest.mark.asyncio
async def test_library_v1_does_not_grant_access_to_pending_v2() -> None:
"""A regular user has v1 in their library. v2 is pending (not approved).
They must NOT get access to v2 — library membership is version-specific."""
requester_id = "regular-user-id"
graph_id = "graph-id"
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch(
"backend.data.graph.StoreListingVersion.prisma",
) as mock_slv_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
):
# Not owned
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
# v2 is not APPROVED in marketplace
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
# Library has v1 but not v2 — version filter excludes it
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
result = await get_graph(
graph_id=graph_id,
version=2,
user_id=requester_id,
)
assert result is None, "Regular user with v1 in library must not access pending v2"
@pytest.mark.asyncio
async def test_admin_can_access_pending_v2_via_get_graph_as_admin() -> None:
"""An admin can access v2 (pending) via get_graph_as_admin even though
only v1 is approved. get_graph_as_admin bypasses all access checks."""
from backend.data.graph import get_graph_as_admin
admin_id = "admin-user-id"
mock_graph = _make_mock_db_graph("creator-user-id")
mock_graph.version = 2
mock_graph_model = MagicMock(name="GraphModel")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
patch(
"backend.data.graph.GraphModel.from_db",
return_value=mock_graph_model,
),
):
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
result = await get_graph_as_admin(
graph_id="graph-id",
version=2,
user_id=admin_id,
for_export=False,
)
assert (
result is mock_graph_model
), "Admin must access pending v2 via get_graph_as_admin"
# --------------- execution permission truth table --------------- #
#
# validate_graph_execution_permissions() has two gates:
# 1. Accessible graph: owner OR exact-version library entry OR marketplace-published
# 2. Runnable graph: exact-version library entry OR owner fallback to any live
# library entry for the graph OR sub-graph exception
#
# Desired owner behavior differs from non-owners:
# owners should be allowed to run a new version when some non-archived/non-deleted
# version of that graph is still in their library. Non-owners stay
# version-specific.
#
# | User | Owns? | Marketplace | Library state | is_sub_graph | Result | Test
# |----------|-------|-------------|------------------------------|--------------|----------|-----
# | regular | no | no | exact version present | false | ALLOW | test_validate_graph_execution_permissions_library_member_same_version_allowed
# | owner | yes | no | exact version present | false | ALLOW | test_validate_graph_execution_permissions_owner_same_version_in_library_allowed
# | owner | yes | no | previous version present | false | ALLOW | test_validate_graph_execution_permissions_owner_previous_library_version_allowed
# | owner | yes | no | none present | false | DENY lib | test_validate_graph_execution_permissions_owner_without_library_denied
# | owner | yes | no | only archived/deleted older | false | DENY lib | test_validate_graph_execution_permissions_owner_previous_archived_library_version_denied
# | regular | no | yes | none present | false | DENY lib | test_validate_graph_execution_permissions_marketplace_graph_not_in_library_denied
# | admin | no | no | none present | false | DENY acc | test_validate_graph_execution_permissions_admin_without_library_or_marketplace_denied
# | regular | no | yes | none present | true | ALLOW | test_validate_graph_execution_permissions_marketplace_sub_graph_without_library_allowed
# | regular | no | no | none present | true | DENY acc | test_validate_graph_execution_permissions_unpublished_sub_graph_without_library_denied
# | regular | no | no | wrong version only | false | DENY acc | test_validate_graph_execution_permissions_library_wrong_version_denied
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_library_member_same_version_allowed() -> (
None
):
requester_id = "library-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId="creator-user-id")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=MagicMock())
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_not_awaited()
lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
assert lib_where["agentGraphVersion"] == graph_version
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_same_version_in_library_allowed() -> (
None
):
requester_id = "owner-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId=requester_id)
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=MagicMock())
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_not_awaited()
lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
assert lib_where["agentGraphVersion"] == graph_version
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_previous_library_version_allowed() -> (
None
):
requester_id = "owner-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId=requester_id)
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(
side_effect=[None, MagicMock(name="PriorVersionLibraryAgent")]
)
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_not_awaited()
assert mock_lib_prisma.return_value.find_first.await_count == 2
first_where = mock_lib_prisma.return_value.find_first.await_args_list[0].kwargs[
"where"
]
second_where = mock_lib_prisma.return_value.find_first.await_args_list[1].kwargs[
"where"
]
assert first_where["agentGraphVersion"] == graph_version
assert "agentGraphVersion" not in second_where
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_without_library_denied() -> (
None
):
requester_id = "owner-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId=requester_id)
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
with pytest.raises(GraphNotInLibraryError):
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_not_awaited()
assert mock_lib_prisma.return_value.find_first.await_count == 2
first_where = mock_lib_prisma.return_value.find_first.await_args_list[0].kwargs[
"where"
]
second_where = mock_lib_prisma.return_value.find_first.await_args_list[1].kwargs[
"where"
]
assert first_where["agentGraphVersion"] == graph_version
assert second_where == {
"userId": requester_id,
"agentGraphId": graph_id,
"isDeleted": False,
"isArchived": False,
}
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_owner_previous_archived_library_version_denied() -> (
None
):
requester_id = "owner-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId=requester_id)
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(side_effect=[None, None])
with pytest.raises(GraphNotInLibraryError):
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_not_awaited()
assert mock_lib_prisma.return_value.find_first.await_count == 2
first_where = mock_lib_prisma.return_value.find_first.await_args_list[0].kwargs[
"where"
]
second_where = mock_lib_prisma.return_value.find_first.await_args_list[1].kwargs[
"where"
]
assert first_where["agentGraphVersion"] == graph_version
assert second_where == {
"userId": requester_id,
"agentGraphId": graph_id,
"isDeleted": False,
"isArchived": False,
}
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_marketplace_graph_not_in_library_denied() -> (
None
):
requester_id = "marketplace-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId="creator-user-id")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=True,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
with pytest.raises(GraphNotInLibraryError):
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_awaited_once_with(graph_id, graph_version)
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_admin_without_library_or_marketplace_denied() -> (
None
):
requester_id = "admin-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId="creator-user-id")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
with pytest.raises(GraphNotAccessibleError):
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_awaited_once_with(graph_id, graph_version)
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_unpublished_sub_graph_without_library_denied() -> (
None
):
requester_id = "marketplace-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId="creator-user-id")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
with pytest.raises(GraphNotAccessibleError):
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
is_sub_graph=True,
)
mock_is_published.assert_awaited_once_with(graph_id, graph_version)
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_marketplace_sub_graph_without_library_allowed() -> (
None
):
requester_id = "marketplace-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId="creator-user-id")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=True,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
is_sub_graph=True,
)
mock_is_published.assert_awaited_once_with(graph_id, graph_version)
@pytest.mark.asyncio
async def test_validate_graph_execution_permissions_library_wrong_version_denied() -> (
None
):
requester_id = "library-user-id"
graph_id = "graph-id"
graph_version = 2
mock_graph = MagicMock(userId="creator-user-id")
with (
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
patch(
"backend.data.graph.is_graph_published_in_marketplace",
new_callable=AsyncMock,
return_value=False,
) as mock_is_published,
):
mock_ag_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph)
mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None)
with pytest.raises(GraphNotAccessibleError):
await validate_graph_execution_permissions(
user_id=requester_id,
graph_id=graph_id,
graph_version=graph_version,
)
mock_is_published.assert_awaited_once_with(graph_id, graph_version)
lib_where = mock_lib_prisma.return_value.find_first.call_args.kwargs["where"]
assert lib_where["agentGraphVersion"] == graph_version

View File

@@ -312,10 +312,21 @@ def SchemaField(
) # type: ignore
# SDK default credentials use IDs like "{provider}-default" (set in sdk/builder.py).
# They must never be exposed to users via the API.
SDK_DEFAULT_SUFFIX = "-default"
def is_sdk_default(cred_id: str) -> bool:
return cred_id.endswith(SDK_DEFAULT_SUFFIX)
class _BaseCredentials(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
provider: str
title: Optional[str] = None
is_managed: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
@field_serializer("*")
def dump_secret_strings(value: Any, _info):
@@ -335,7 +346,6 @@ class OAuth2Credentials(_BaseCredentials):
refresh_token_expires_at: Optional[int] = None
"""Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
scopes: list[str]
metadata: dict[str, Any] = Field(default_factory=dict)
def auth_header(self) -> str:
return f"Bearer {self.access_token.get_secret_value()}"
@@ -713,7 +723,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,
discriminator_mapping=self.discriminator_mapping,
discriminator_values=self.discriminator_values,
discriminator_values=set(self.discriminator_values), # defensive copy
)
@@ -889,6 +899,10 @@ class GraphExecutionStats(BaseModel):
default=None,
description="AI-generated score (0.0-1.0) indicating how well the execution achieved its intended purpose",
)
is_dry_run: bool = Field(
default=False,
description="Whether this execution was a dry-run simulation",
)
class UserExecutionSummaryStats(BaseModel):

View File

@@ -40,6 +40,9 @@ _MAX_PAGES = 100
# LLM extraction timeout (seconds)
_LLM_TIMEOUT = 30
SUGGESTION_THEMES = ["Learn", "Create", "Automate", "Organize"]
PROMPTS_PER_THEME = 5
def _mask_email(email: str) -> str:
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
@@ -332,6 +335,11 @@ Fields:
- current_software (list of strings): software/tools currently used
- existing_automation (list of strings): existing automations
- additional_notes (string): any additional context
- suggested_prompts (object with keys "Learn", "Create", "Automate", "Organize"): for each key, \
provide a list of 5 short action prompts (each under 20 words) that would help this person. \
"Learn" = questions about AutoGPT features; "Create" = content/document generation tasks; \
"Automate" = recurring workflow automation ideas; "Organize" = structuring/prioritizing tasks. \
Should be specific to their industry, role, and pain points; actionable and conversational in tone.
Form data:
"""
@@ -378,6 +386,29 @@ async def extract_business_understanding(
# Filter out null values before constructing
cleaned = {k: v for k, v in data.items() if v is not None}
# Validate suggested_prompts: themed dict, filter >20 words, cap at 5 per theme
raw_prompts = cleaned.get("suggested_prompts", {})
if isinstance(raw_prompts, dict):
themed: dict[str, list[str]] = {}
for theme in SUGGESTION_THEMES:
theme_prompts = raw_prompts.get(theme, [])
if not isinstance(theme_prompts, list):
continue
valid = [
s
for p in theme_prompts
if isinstance(p, str) and (s := p.strip()) and len(s.split()) <= 20
]
if valid:
themed[theme] = valid[:PROMPTS_PER_THEME]
if themed:
cleaned["suggested_prompts"] = themed
else:
cleaned.pop("suggested_prompts", None)
else:
cleaned.pop("suggested_prompts", None)
return BusinessUnderstandingInput(**cleaned)

View File

@@ -284,6 +284,7 @@ async def test_populate_understanding_full_flow():
],
}
mock_input = MagicMock()
mock_input.suggested_prompts = {"Learn": ["P1"], "Create": ["P2"]}
with (
patch(
@@ -397,15 +398,25 @@ def test_extraction_prompt_no_format_placeholders():
@pytest.mark.asyncio
async def test_extract_business_understanding_success():
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
async def test_extract_business_understanding_themed_prompts():
"""Happy path: LLM returns themed prompts as dict."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{
"user_name": "Alice",
"business_name": "Acme Corp",
"industry": "Technology",
"pain_points": ["manual reporting"],
"suggested_prompts": {
"Learn": ["Learn 1", "Learn 2", "Learn 3", "Learn 4", "Learn 5"],
"Create": [
"Create 1",
"Create 2",
"Create 3",
"Create 4",
"Create 5",
],
"Automate": ["Auto 1", "Auto 2", "Auto 3", "Auto 4", "Auto 5"],
"Organize": ["Org 1", "Org 2", "Org 3", "Org 4", "Org 5"],
},
}
)
mock_response = MagicMock()
@@ -418,9 +429,42 @@ async def test_extract_business_understanding_success():
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name == "Acme Corp"
assert result.industry == "Technology"
assert result.pain_points == ["manual reporting"]
assert result.suggested_prompts is not None
assert len(result.suggested_prompts) == 4
assert len(result.suggested_prompts["Learn"]) == 5
@pytest.mark.asyncio
async def test_extract_themed_prompts_filters_long_and_unknown_keys():
"""Long prompts are filtered, unknown keys are dropped, each theme capped at 5."""
long_prompt = " ".join(["word"] * 21)
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{
"user_name": "Alice",
"suggested_prompts": {
"Learn": [long_prompt, "Valid learn 1", "Valid learn 2"],
"UnknownTheme": ["Should be dropped"],
"Automate": ["A1", "A2", "A3", "A4", "A5", "A6"],
},
}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.suggested_prompts is not None
# Unknown key dropped
assert "UnknownTheme" not in result.suggested_prompts
# Long prompt filtered
assert result.suggested_prompts["Learn"] == ["Valid learn 1", "Valid learn 2"]
# Capped at 5
assert result.suggested_prompts["Automate"] == ["A1", "A2", "A3", "A4", "A5"]
@pytest.mark.asyncio

View File

@@ -49,6 +49,25 @@ def _json_to_list(value: Any) -> list[str]:
return []
def _json_to_themed_prompts(value: Any) -> dict[str, list[str]]:
"""Convert Json field to themed prompts dict.
Handles both the new ``dict[str, list[str]]`` format and the legacy
``list[str]`` format. Legacy rows are placed under a ``"General"`` key so
existing personalised prompts remain readable until a backfill regenerates
them into the proper themed shape.
"""
if isinstance(value, dict):
return {
k: [i for i in v if isinstance(i, str)]
for k, v in value.items()
if isinstance(k, str) and isinstance(v, list)
}
if isinstance(value, list) and value:
return {"General": [str(p) for p in value if isinstance(p, str)]}
return {}
class BusinessUnderstandingInput(pydantic.BaseModel):
"""Input model for updating business understanding - all fields optional for incremental updates."""
@@ -104,6 +123,11 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
None, description="Any additional context"
)
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: Optional[dict[str, list[str]]] = pydantic.Field(
None, description="LLM-generated suggested prompts grouped by theme"
)
class BusinessUnderstanding(pydantic.BaseModel):
"""Full business understanding model returned from database."""
@@ -140,6 +164,9 @@ class BusinessUnderstanding(pydantic.BaseModel):
# Additional context
additional_notes: Optional[str] = None
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: dict[str, list[str]] = pydantic.Field(default_factory=dict)
@classmethod
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
"""Convert database record to Pydantic model."""
@@ -167,6 +194,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
current_software=_json_to_list(business.get("current_software")),
existing_automation=_json_to_list(business.get("existing_automation")),
additional_notes=business.get("additional_notes"),
suggested_prompts=_json_to_themed_prompts(data.get("suggested_prompts")),
)
@@ -246,33 +274,22 @@ async def get_business_understanding(
return understanding
async def upsert_business_understanding(
user_id: str,
def merge_business_understanding_data(
existing_data: dict[str, Any],
input_data: BusinessUnderstandingInput,
) -> BusinessUnderstanding:
"""
Create or update business understanding with incremental merge strategy.
) -> dict[str, Any]:
"""Merge new input into existing data dict using incremental strategy.
- String fields: new value overwrites if provided (not None)
- List fields: new items are appended to existing (deduplicated)
- suggested_prompts: fully replaced if provided (not None)
Data is stored as: {name: ..., business: {version: 1, ...}}
Returns the merged data dict (mutates and returns *existing_data*).
"""
# Get existing record for merge
existing = await CoPilotUnderstanding.prisma().find_unique(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
@@ -310,16 +327,48 @@ async def upsert_business_understanding(
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Suggested prompts - fully replace if provided
if input_data.suggested_prompts is not None:
existing_data["suggested_prompts"] = input_data.suggested_prompts
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
return existing_data
async def upsert_business_understanding(
user_id: str,
input_data: BusinessUnderstandingInput,
) -> BusinessUnderstanding:
"""
Create or update business understanding with incremental merge strategy.
- String fields: new value overwrites if provided (not None)
- List fields: new items are appended to existing (deduplicated)
- suggested_prompts: fully replaced if provided (not None)
Data is stored as: {name: ..., business: {version: 1, ...}}
"""
# Get existing record for merge
existing = await CoPilotUnderstanding.prisma().find_unique(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
merged_data = merge_business_understanding_data(existing_data, input_data)
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
"create": {"userId": user_id, "data": SafeJson(merged_data)},
"update": {"data": SafeJson(merged_data)},
},
)

View File

@@ -0,0 +1,148 @@
"""Tests for business understanding merge and format logic."""
from datetime import datetime, timezone
from typing import Any
from unittest.mock import MagicMock
from backend.data.understanding import (
BusinessUnderstanding,
BusinessUnderstandingInput,
_json_to_themed_prompts,
format_understanding_for_prompt,
merge_business_understanding_data,
)
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
"""Create a BusinessUnderstandingInput with only the specified fields."""
return BusinessUnderstandingInput.model_validate(kwargs)
# ─── merge_business_understanding_data: themed prompts ─────────────────
def test_merge_themed_prompts_overwrites_existing():
"""New themed prompts should fully replace existing ones (not merge)."""
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": {
"Learn": ["Old learn prompt"],
"Create": ["Old create prompt"],
},
}
new_prompts = {
"Automate": ["Schedule daily reports", "Set up email alerts"],
"Organize": ["Sort inbox by priority"],
}
input_data = _make_input(suggested_prompts=new_prompts)
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == new_prompts
def test_merge_themed_prompts_none_preserves_existing():
"""When input has suggested_prompts=None, existing themed prompts are preserved."""
existing_prompts = {
"Learn": ["How to automate?"],
"Create": ["Build a chatbot"],
}
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": existing_prompts,
}
input_data = _make_input(industry="Finance")
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == existing_prompts
assert result["business"]["industry"] == "Finance"
# ─── from_db: themed prompts deserialization ───────────────────────────
def test_from_db_themed_prompts():
"""from_db correctly deserializes a themed dict for suggested_prompts."""
themed = {
"Learn": ["What can I automate?"],
"Create": ["Build a workflow"],
}
db_record = MagicMock()
db_record.id = "test-id"
db_record.userId = "user-1"
db_record.createdAt = datetime.now(tz=timezone.utc)
db_record.updatedAt = datetime.now(tz=timezone.utc)
db_record.data = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": themed,
}
result = BusinessUnderstanding.from_db(db_record)
assert result.suggested_prompts == themed
def test_from_db_legacy_list_prompts_preserved_under_general():
"""from_db preserves legacy list[str] prompts under a 'General' key."""
db_record = MagicMock()
db_record.id = "test-id"
db_record.userId = "user-1"
db_record.createdAt = datetime.now(tz=timezone.utc)
db_record.updatedAt = datetime.now(tz=timezone.utc)
db_record.data = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
}
result = BusinessUnderstanding.from_db(db_record)
assert result.suggested_prompts == {"General": ["Old prompt 1", "Old prompt 2"]}
# ─── _json_to_themed_prompts helper ───────────────────────────────────
def test_json_to_themed_prompts_with_dict():
value = {"Learn": ["a", "b"], "Create": ["c"]}
assert _json_to_themed_prompts(value) == {"Learn": ["a", "b"], "Create": ["c"]}
def test_json_to_themed_prompts_with_list_returns_general():
assert _json_to_themed_prompts(["a", "b"]) == {"General": ["a", "b"]}
def test_json_to_themed_prompts_with_none_returns_empty():
assert _json_to_themed_prompts(None) == {}
# ─── format_understanding_for_prompt: excludes themed prompts ──────────
def test_format_understanding_excludes_themed_prompts():
"""Themed suggested_prompts are UI-only and must NOT appear in the system prompt."""
understanding = BusinessUnderstanding(
id="test-id",
user_id="user-1",
created_at=datetime.now(tz=timezone.utc),
updated_at=datetime.now(tz=timezone.utc),
user_name="Alice",
industry="Technology",
suggested_prompts={
"Learn": ["Automate reports"],
"Create": ["Set up alerts", "Track KPIs"],
},
)
formatted = format_understanding_for_prompt(understanding)
assert "Alice" in formatted
assert "Technology" in formatted
assert "suggested_prompts" not in formatted
assert "Automate reports" not in formatted
assert "Set up alerts" not in formatted
assert "Track KPIs" not in formatted

View File

@@ -3,7 +3,7 @@ import hashlib
import hmac
import logging
from datetime import datetime, timedelta
from typing import Optional, cast
from typing import TYPE_CHECKING, Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
@@ -21,6 +21,9 @@ from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.credentials_store import IntegrationCredentialsStore
logger = logging.getLogger(__name__)
settings = Settings()
@@ -453,6 +456,27 @@ async def unsubscribe_user_by_token(token: str) -> None:
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
async def cleanup_user_managed_credentials(
user_id: str,
store: Optional["IntegrationCredentialsStore"] = None,
) -> None:
"""Revoke all externally-provisioned managed credentials for *user_id*.
Call this before deleting a user account so that external resources
(e.g. AgentMail pods, pod-scoped API keys) are properly cleaned up.
The credential rows themselves are cascade-deleted with the User row.
Pass an existing *store* for testability; when omitted a fresh instance
is created.
"""
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.managed_credentials import cleanup_managed_credentials
if store is None:
store = IntegrationCredentialsStore()
await cleanup_managed_credentials(user_id, store)
async def update_user_timezone(user_id: str, timezone: str) -> User:
"""Update a user's timezone setting."""
try:

View File

@@ -9,6 +9,7 @@ from datetime import datetime, timezone
from typing import Optional
import pydantic
from prisma.errors import UniqueViolationError
from prisma.models import UserWorkspace, UserWorkspaceFile
from prisma.types import UserWorkspaceFileWhereInput
@@ -75,8 +76,7 @@ async def get_or_create_workspace(user_id: str) -> Workspace:
"""
Get user's workspace, creating one if it doesn't exist.
Uses upsert to handle race conditions when multiple concurrent requests
attempt to create a workspace for the same user.
Uses upsert to atomically handle concurrent creation attempts.
Args:
user_id: The user's ID
@@ -84,13 +84,19 @@ async def get_or_create_workspace(user_id: str) -> Workspace:
Returns:
Workspace instance
"""
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id},
"update": {}, # No updates needed if exists
},
)
try:
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id},
"update": {}, # No-op update; workspace already exists
},
)
except UniqueViolationError:
# Defense-in-depth: should not happen with upsert, but handle gracefully
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
if workspace is None:
raise
return Workspace.from_db(workspace)
@@ -123,6 +129,13 @@ async def create_workspace_file(
"""
Create a new workspace file record.
Raises ``UniqueViolationError`` if a record with the same
``(workspaceId, path)`` already exists. The caller
(``WorkspaceManager._persist_db_record``) relies on this to trigger
its delete-old-file-then-retry flow, which cleans up the old storage
blob before re-creating the DB record. Using ``upsert`` here would
silently overwrite ``storagePath`` and orphan the old blob in storage.
Args:
workspace_id: The workspace ID
file_id: The file ID (same as used in storage path for consistency)

View File

@@ -81,6 +81,7 @@ from backend.util.settings import Settings
from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
from .simulator import simulate_block
from .utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
@@ -222,7 +223,9 @@ async def execute_node(
raise ValueError(f"Block {node_block.id} is disabled and cannot be executed")
# Sanity check: validate the execution input.
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
input_data, error = validate_exec(
node, data.inputs, resolve_input=False, dry_run=execution_context.dry_run
)
if input_data is None:
log_metadata.warning(f"Skip execution, input validation error: {error}")
yield "error", error
@@ -372,9 +375,12 @@ async def execute_node(
scope.set_tag(f"execution_context.{k}", v)
try:
async for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
if execution_context.dry_run:
block_iter = simulate_block(node_block, input_data)
else:
block_iter = node_block.execute(input_data, **extra_exec_kwargs)
async for output_name, output_data in block_iter:
output_data = json.to_dict(output_data)
output_size += len(json.dumps(output_data))
log_metadata.debug("Node produced output", **{output_name: output_data})
@@ -506,7 +512,9 @@ async def _enqueue_next_nodes(
next_node_input.update(node_input_mask)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
next_node_input, validation_msg = validate_exec(
next_node, next_node_input, dry_run=execution_context.dry_run
)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
# Incomplete input data, skip queueing the execution.
@@ -551,7 +559,9 @@ async def _enqueue_next_nodes(
if node_input_mask:
idata.update(node_input_mask)
idata, msg = validate_exec(next_node, idata)
idata, msg = validate_exec(
next_node, idata, dry_run=execution_context.dry_run
)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
if not idata:
log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
@@ -829,9 +839,12 @@ class ExecutionProcessor:
return
if exec_meta.stats is None:
exec_stats = GraphExecutionStats()
exec_stats = GraphExecutionStats(
is_dry_run=graph_exec.execution_context.dry_run,
)
else:
exec_stats = exec_meta.stats.to_db()
exec_stats.is_dry_run = graph_exec.execution_context.dry_run
timing_info, status = self._on_graph_execution(
graph_exec=graph_exec,
@@ -971,7 +984,10 @@ class ExecutionProcessor:
running_node_evaluation = self.running_node_evaluation
try:
if db_client.get_credits(graph_exec.user_id) <= 0:
if (
not graph_exec.execution_context.dry_run
and db_client.get_credits(graph_exec.user_id) <= 0
):
raise InsufficientBalanceError(
user_id=graph_exec.user_id,
message="You have no credits left to run an agent.",
@@ -1042,21 +1058,24 @@ class ExecutionProcessor:
f"for node {queued_node_exec.node_id}",
)
# Charge usage (may raise) ------------------------------
# Charge usage (may raise) — skipped for dry runs
try:
cost, remaining_balance = self._charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(graph_exec.user_id),
)
with execution_stats_lock:
execution_stats.cost += cost
# Check if we crossed the low balance threshold
self._handle_low_balance(
db_client=db_client,
user_id=graph_exec.user_id,
current_balance=remaining_balance,
transaction_cost=cost,
)
if not graph_exec.execution_context.dry_run:
cost, remaining_balance = self._charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(
graph_exec.user_id
),
)
with execution_stats_lock:
execution_stats.cost += cost
# Check if we crossed the low balance threshold
self._handle_low_balance(
db_client=db_client,
user_id=graph_exec.user_id,
current_balance=remaining_balance,
transaction_cost=cost,
)
except InsufficientBalanceError as balance_error:
error = balance_error # Set error to trigger FAILED status
node_exec_id = queued_node_exec.node_exec_id

View File

@@ -0,0 +1,218 @@
"""
LLM-powered block simulator for dry-run execution.
When dry_run=True, instead of calling the real block, this module
role-plays the block's execution using an LLM. No real API calls,
no side effects. The LLM is grounded by:
- Block name and description
- Input/output schemas (from block.input_schema.jsonschema() / output_schema.jsonschema())
- The actual input values
Inspired by https://github.com/Significant-Gravitas/agent-simulator
"""
import json
import logging
from collections.abc import AsyncIterator
from typing import Any
from backend.util.clients import get_openai_client
logger = logging.getLogger(__name__)
# Use the same fast/cheap model the copilot uses for non-primary tasks.
# Overridable via ChatConfig.title_model if ChatConfig is available.
def _simulator_model() -> str:
try:
from backend.copilot.config import ChatConfig # noqa: PLC0415
model = ChatConfig().title_model
except Exception:
model = "openai/gpt-4o-mini"
# get_openai_client() may return a direct OpenAI client (not OpenRouter).
# Direct OpenAI expects bare model names ("gpt-4o-mini"), not the
# OpenRouter-prefixed form ("openai/gpt-4o-mini"). Strip the prefix when
# the internal OpenAI key is configured (i.e. not going through OpenRouter).
try:
from backend.util.settings import Settings # noqa: PLC0415
secrets = Settings().secrets
# get_openai_client() uses the direct OpenAI client whenever
# openai_internal_api_key is set, regardless of open_router_api_key.
# Strip the provider prefix (e.g. "openai/gpt-4o-mini" → "gpt-4o-mini")
# so the model name is valid for the direct OpenAI API.
if secrets.openai_internal_api_key and "/" in model:
model = model.split("/", 1)[1]
except Exception:
pass
return model
_TEMPERATURE = 0.2
_MAX_JSON_RETRIES = 5
_MAX_INPUT_VALUE_CHARS = 20000
def _truncate_value(value: Any) -> Any:
"""Recursively truncate long strings anywhere in a value."""
if isinstance(value, str):
return (
value[:_MAX_INPUT_VALUE_CHARS] + "... [TRUNCATED]"
if len(value) > _MAX_INPUT_VALUE_CHARS
else value
)
if isinstance(value, dict):
return {k: _truncate_value(v) for k, v in value.items()}
if isinstance(value, list):
return [_truncate_value(item) for item in value]
return value
def _truncate_input_values(input_data: dict[str, Any]) -> dict[str, Any]:
"""Recursively truncate long string values so the prompt doesn't blow up."""
return {k: _truncate_value(v) for k, v in input_data.items()}
def _describe_schema_pins(schema: dict[str, Any]) -> str:
"""Format output pins as a bullet list for the prompt."""
properties = schema.get("properties", {})
required = set(schema.get("required", []))
lines = []
for pin_name, pin_schema in properties.items():
pin_type = pin_schema.get("type", "any")
req = "required" if pin_name in required else "optional"
lines.append(f"- {pin_name}: {pin_type} ({req})")
return "\n".join(lines) if lines else "(no output pins defined)"
def build_simulation_prompt(block: Any, input_data: dict[str, Any]) -> tuple[str, str]:
"""Build (system_prompt, user_prompt) for block simulation."""
input_schema = block.input_schema.jsonschema()
output_schema = block.output_schema.jsonschema()
input_pins = _describe_schema_pins(input_schema)
output_pins = _describe_schema_pins(output_schema)
output_properties = list(output_schema.get("properties", {}).keys())
block_name = getattr(block, "name", type(block).__name__)
block_description = getattr(block, "description", "No description available.")
system_prompt = f"""You are simulating the execution of a software block called "{block_name}".
## Block Description
{block_description}
## Input Schema
{input_pins}
## Output Schema (what you must return)
{output_pins}
Your task: given the current inputs, produce realistic simulated outputs for this block.
Rules:
- Respond with a single JSON object whose keys are EXACTLY the output pin names listed above.
- Assume all credentials and authentication are present and valid. Never simulate authentication failures.
- Make the simulated outputs realistic and consistent with the inputs.
- If there is an "error" pin, set it to "" (empty string) unless you are simulating a logical error.
- Do not include any extra keys beyond the output pins.
Output pin names you MUST include: {json.dumps(output_properties)}
"""
safe_inputs = _truncate_input_values(input_data)
user_prompt = f"## Current Inputs\n{json.dumps(safe_inputs, indent=2)}"
return system_prompt, user_prompt
async def simulate_block(
block: Any,
input_data: dict[str, Any],
) -> AsyncIterator[tuple[str, Any]]:
"""Simulate block execution using an LLM.
Yields (output_name, output_data) tuples matching the Block.execute() interface.
On unrecoverable failure, yields a single ("error", "[SIMULATOR ERROR ...") tuple.
"""
client = get_openai_client()
if client is None:
yield (
"error",
"[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available "
"(missing OpenAI/OpenRouter API key).",
)
return
output_schema = block.output_schema.jsonschema()
output_properties: dict[str, Any] = output_schema.get("properties", {})
system_prompt, user_prompt = build_simulation_prompt(block, input_data)
model = _simulator_model()
last_error: Exception | None = None
for attempt in range(_MAX_JSON_RETRIES):
try:
response = await client.chat.completions.create(
model=model,
temperature=_TEMPERATURE,
response_format={"type": "json_object"},
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
)
if not response.choices:
raise ValueError("LLM returned empty choices array")
raw = response.choices[0].message.content or ""
parsed = json.loads(raw)
if not isinstance(parsed, dict):
raise ValueError(f"LLM returned non-object JSON: {raw[:200]}")
# Fill missing output pins with defaults
result: dict[str, Any] = {}
for pin_name in output_properties:
if pin_name in parsed:
result[pin_name] = parsed[pin_name]
else:
result[pin_name] = "" if pin_name == "error" else None
logger.debug(
"simulate_block: block=%s attempt=%d tokens=%s/%s",
getattr(block, "name", "?"),
attempt + 1,
getattr(getattr(response, "usage", None), "prompt_tokens", "?"),
getattr(getattr(response, "usage", None), "completion_tokens", "?"),
)
for pin_name, pin_value in result.items():
yield pin_name, pin_value
return
except (json.JSONDecodeError, ValueError) as e:
last_error = e
logger.warning(
"simulate_block: JSON parse error on attempt %d/%d: %s",
attempt + 1,
_MAX_JSON_RETRIES,
e,
)
except Exception as e:
last_error = e
logger.error("simulate_block: LLM call failed: %s", e, exc_info=True)
break
logger.error(
"simulate_block: all %d retries exhausted for block=%s; last_error=%s",
_MAX_JSON_RETRIES,
getattr(block, "name", "?"),
last_error,
)
yield (
"error",
f"[SIMULATOR ERROR — NOT A BLOCK FAILURE] Failed after {_MAX_JSON_RETRIES} "
f"attempts: {last_error}",
)

View File

@@ -181,6 +181,7 @@ def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
dry_run: bool = False,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
@@ -189,6 +190,9 @@ def validate_exec(
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
dry_run: When True, credential fields are allowed to be missing — they
will be substituted with a sentinel so the node can be queued and
later executed via simulate_block.
Returns:
A tuple of the validated data and the block name.
@@ -207,6 +211,14 @@ def validate_exec(
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# For dry runs, supply sentinel values for any missing credential fields so
# the node can be queued — simulate_block never calls the real API anyway.
if dry_run:
cred_field_names = set(schema.get_credentials_fields().keys())
for field_name in cred_field_names:
if field_name not in data:
data = {**data, field_name: None}
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
@@ -218,13 +230,21 @@ def validate_exec(
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
if dry_run:
# In dry-run mode all missing inputs are tolerated — simulate_block()
# generates synthetic outputs without needing real input values.
pass
else:
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.warning(error_message)
return None, error_message
# Skip for dry runs — simulate_block doesn't use real inputs, and sentinel
# credential values (None) would fail JSON-schema type/required checks.
if not dry_run:
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.warning(error_message)
return None, error_message
return data, node_block.name
@@ -427,6 +447,7 @@ async def _construct_starting_node_execution_input(
user_id: str,
graph_inputs: GraphInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
dry_run: bool = False,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
"""
Validates and prepares the input data for executing a graph.
@@ -439,6 +460,7 @@ async def _construct_starting_node_execution_input(
user_id (str): The ID of the user executing the graph.
data (GraphInput): The input data for the graph execution.
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
dry_run: When True, skip credential validation errors (simulation needs no real creds).
Returns:
tuple[
@@ -451,6 +473,32 @@ async def _construct_starting_node_execution_input(
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
# Dry runs simulate every block — missing credentials are irrelevant.
# Strip credential-only errors so the graph can proceed.
if dry_run and validation_errors:
def _is_credential_error(msg: str) -> bool:
"""Match errors produced by _validate_node_input_credentials."""
m = msg.lower()
return (
m == "these credentials are required"
or m.startswith("invalid credentials:")
or m.startswith("credentials not available:")
or m.startswith("unknown credentials #")
)
validation_errors = {
node_id: {
field: msg
for field, msg in errors.items()
if not _is_credential_error(msg)
}
for node_id, errors in validation_errors.items()
}
# Remove nodes that have no remaining errors
validation_errors = {
node_id: errors for node_id, errors in validation_errors.items() if errors
}
n_error_nodes = len(validation_errors)
n_errors = sum(len(errors) for errors in validation_errors.values())
if validation_errors:
@@ -494,7 +542,7 @@ async def _construct_starting_node_execution_input(
"Please use the appropriate trigger to run this agent."
)
input_data, error = validate_exec(node, input_data)
input_data, error = validate_exec(node, input_data, dry_run=dry_run)
if input_data is None:
raise ValueError(error)
else:
@@ -516,6 +564,7 @@ async def validate_and_construct_node_execution_input(
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
dry_run: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
@@ -581,6 +630,7 @@ async def validate_and_construct_node_execution_input(
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
dry_run=dry_run,
)
)
@@ -818,6 +868,7 @@ async def add_graph_execution(
nodes_input_masks: Optional[NodesInputMasks] = None,
execution_context: Optional[ExecutionContext] = None,
graph_exec_id: Optional[str] = None,
dry_run: bool = False,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
@@ -882,6 +933,7 @@ async def add_graph_execution(
graph_credentials_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
is_sub_graph=parent_exec_id is not None,
dry_run=dry_run,
)
)
@@ -895,6 +947,7 @@ async def add_graph_execution(
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
parent_graph_exec_id=parent_exec_id,
is_dry_run=dry_run,
)
logger.info(
@@ -917,6 +970,7 @@ async def add_graph_execution(
# Safety settings
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
dry_run=dry_run,
# User settings
user_timezone=(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"

View File

@@ -425,6 +425,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
parent_graph_exec_id=None,
is_dry_run=False,
)
# Set up the graph execution mock to have properties we can extract

View File

@@ -1,5 +1,6 @@
import base64
import hashlib
import logging
import secrets
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
@@ -21,6 +22,7 @@ from backend.data.redis_client import get_redis_async
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
def provider_matches(stored: str, expected: str) -> bool:
@@ -284,6 +286,7 @@ DEFAULT_CREDENTIALS = [
elevenlabs_credentials,
]
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
# Set of providers that have system credentials available
@@ -323,20 +326,45 @@ class IntegrationCredentialsStore:
return get_database_manager_async_client()
# =============== USER-MANAGED CREDENTIALS =============== #
async def _get_persisted_user_creds_unlocked(
self, user_id: str
) -> list[Credentials]:
"""Return only the persisted (user-stored) credentials — no side effects.
**Caller must already hold ``locked_user_integrations(user_id)``.**
"""
return list((await self._get_user_integrations(user_id)).credentials)
async def add_creds(self, user_id: str, credentials: Credentials) -> None:
async with await self.locked_user_integrations(user_id):
if await self.get_creds_by_id(user_id, credentials.id):
# Check system/managed IDs without triggering provisioning
if credentials.id in SYSTEM_CREDENTIAL_IDS:
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
)
await self._set_user_integration_creds(
user_id, [*(await self.get_all_creds(user_id)), credentials]
)
persisted = await self._get_persisted_user_creds_unlocked(user_id)
if any(c.id == credentials.id for c in persisted):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
)
await self._set_user_integration_creds(user_id, [*persisted, credentials])
async def get_all_creds(self, user_id: str) -> list[Credentials]:
users_credentials = (await self._get_user_integrations(user_id)).credentials
all_credentials = users_credentials
"""Public entry point — acquires lock, then delegates."""
async with await self.locked_user_integrations(user_id):
return await self._get_all_creds_unlocked(user_id)
async def _get_all_creds_unlocked(self, user_id: str) -> list[Credentials]:
"""Return all credentials for *user_id*.
**Caller must already hold ``locked_user_integrations(user_id)``.**
"""
user_integrations = await self._get_user_integrations(user_id)
all_credentials = list(user_integrations.credentials)
# These will always be added
all_credentials.append(ollama_credentials)
@@ -417,13 +445,22 @@ class IntegrationCredentialsStore:
return list(set(c.provider for c in credentials))
async def update_creds(self, user_id: str, updated: Credentials) -> None:
if updated.id in SYSTEM_CREDENTIAL_IDS:
raise ValueError(
f"System credential #{updated.id} cannot be updated directly"
)
async with await self.locked_user_integrations(user_id):
current = await self.get_creds_by_id(user_id, updated.id)
persisted = await self._get_persisted_user_creds_unlocked(user_id)
current = next((c for c in persisted if c.id == updated.id), None)
if not current:
raise ValueError(
f"Credentials with ID {updated.id} "
f"for user with ID {user_id} not found"
)
if current.is_managed:
raise ValueError(
f"AutoGPT-managed credential #{updated.id} cannot be updated"
)
if type(current) is not type(updated):
raise TypeError(
f"Can not update credentials with ID {updated.id} "
@@ -443,22 +480,53 @@ class IntegrationCredentialsStore:
f"to more restrictive set of scopes {updated.scopes}"
)
# Update the credentials
# Update only persisted credentials — no side-effectful provisioning
updated_credentials_list = [
updated if c.id == updated.id else c
for c in await self.get_all_creds(user_id)
updated if c.id == updated.id else c for c in persisted
]
await self._set_user_integration_creds(user_id, updated_credentials_list)
async def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
if credentials_id in SYSTEM_CREDENTIAL_IDS:
raise ValueError(f"System credential #{credentials_id} cannot be deleted")
async with await self.locked_user_integrations(user_id):
filtered_credentials = [
c for c in await self.get_all_creds(user_id) if c.id != credentials_id
]
persisted = await self._get_persisted_user_creds_unlocked(user_id)
target = next((c for c in persisted if c.id == credentials_id), None)
if target and target.is_managed:
raise ValueError(
f"AutoGPT-managed credential #{credentials_id} cannot be deleted"
)
filtered_credentials = [c for c in persisted if c.id != credentials_id]
await self._set_user_integration_creds(user_id, filtered_credentials)
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
async def has_managed_credential(self, user_id: str, provider: str) -> bool:
"""Check if a managed credential exists for *provider*."""
user_integrations = await self._get_user_integrations(user_id)
return any(
c.provider == provider and c.is_managed
for c in user_integrations.credentials
)
async def add_managed_credential(
self, user_id: str, credential: Credentials
) -> None:
"""Upsert a managed credential.
Removes any existing managed credential for the same provider,
then appends the new one. The credential MUST have is_managed=True.
"""
if not credential.is_managed:
raise ValueError("credential.is_managed must be True")
async with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.credentials = [
c
for c in user_integrations.credentials
if not (c.provider == credential.provider and c.is_managed)
]
user_integrations.credentials.append(credential)
async def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
"""Set the Ayrshare profile key for a user.

View File

@@ -0,0 +1,188 @@
"""Generic infrastructure for system-provided, per-user managed credentials.
Managed credentials are provisioned automatically by the platform (e.g. an
AgentMail pod-scoped API key) and stored alongside regular user credentials
with ``is_managed=True``. Users cannot update or delete them.
New integrations register a :class:`ManagedCredentialProvider` at import time;
the two entry-points consumed by the rest of the application are:
* :func:`ensure_managed_credentials` fired as a background task from the
credential-listing endpoints (non-blocking).
* :func:`cleanup_managed_credentials` called during account deletion to
revoke external resources (API keys, pods, etc.).
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from cachetools import TTLCache
if TYPE_CHECKING:
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Abstract provider
# ---------------------------------------------------------------------------
class ManagedCredentialProvider(ABC):
"""Base class for integrations that auto-provision per-user credentials."""
provider_name: str
"""Must match the ``provider`` field on the resulting credential."""
@abstractmethod
async def is_available(self) -> bool:
"""Return ``True`` when the org-level configuration is present."""
@abstractmethod
async def provision(self, user_id: str) -> Credentials:
"""Create external resources and return a credential.
The returned credential **must** have ``is_managed=True``.
"""
@abstractmethod
async def deprovision(self, user_id: str, credential: Credentials) -> None:
"""Revoke external resources during account deletion."""
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
_PROVIDERS: dict[str, ManagedCredentialProvider] = {}
# Users whose managed credentials have already been verified recently.
# Avoids redundant DB checks on every GET /credentials call.
# maxsize caps memory; TTL re-checks periodically (e.g. when new providers
# are added). ~100K entries ≈ 4-8 MB.
_provisioned_users: TTLCache[str, bool] = TTLCache(maxsize=100_000, ttl=3600)
def register_managed_provider(provider: ManagedCredentialProvider) -> None:
_PROVIDERS[provider.provider_name] = provider
def get_managed_provider(name: str) -> ManagedCredentialProvider | None:
return _PROVIDERS.get(name)
def get_managed_providers() -> dict[str, ManagedCredentialProvider]:
return dict(_PROVIDERS)
# ---------------------------------------------------------------------------
# Public helpers
# ---------------------------------------------------------------------------
async def _ensure_one(
user_id: str,
store: IntegrationCredentialsStore,
name: str,
provider: ManagedCredentialProvider,
) -> bool:
"""Provision a single managed credential under a distributed Redis lock.
Returns ``True`` if the credential already exists or was successfully
provisioned, ``False`` on transient failure so the caller knows not to
cache the user as fully provisioned.
"""
try:
if not await provider.is_available():
return True
# Use a distributed Redis lock so the check-then-provision operation
# is atomic across all workers, preventing duplicate external
# resource provisioning (e.g. AgentMail API keys).
locks = await store.locks()
key = (f"user:{user_id}", f"managed-provision:{name}")
async with locks.locked(key):
# Re-check under lock to avoid duplicate provisioning.
if await store.has_managed_credential(user_id, name):
return True
credential = await provider.provision(user_id)
await store.add_managed_credential(user_id, credential)
logger.info(
"Provisioned managed credential for provider=%s user=%s",
name,
user_id,
)
return True
except Exception:
logger.warning(
"Failed to provision managed credential for provider=%s user=%s",
name,
user_id,
exc_info=True,
)
return False
async def ensure_managed_credentials(
user_id: str,
store: IntegrationCredentialsStore,
) -> None:
"""Provision missing managed credentials for *user_id*.
Fired as a non-blocking background task from the credential-listing
endpoints. Failures are logged but never propagated — the user simply
will not see the managed credential until the next page load.
Skips entirely if this user has already been checked during the current
process lifetime (in-memory cache). Resets on restart — just a
performance optimisation, not a correctness guarantee.
Providers are checked concurrently via ``asyncio.gather``.
"""
if user_id in _provisioned_users:
return
results = await asyncio.gather(
*(_ensure_one(user_id, store, n, p) for n, p in _PROVIDERS.items())
)
# Only cache the user as provisioned when every provider succeeded or
# was already present. A transient failure (network timeout, Redis
# blip) returns False, so the next page load will retry.
if all(results):
_provisioned_users[user_id] = True
async def cleanup_managed_credentials(
user_id: str,
store: IntegrationCredentialsStore,
) -> None:
"""Revoke all external managed resources for a user being deleted."""
all_creds = await store.get_all_creds(user_id)
managed = [c for c in all_creds if c.is_managed]
for cred in managed:
provider = _PROVIDERS.get(cred.provider)
if not provider:
logger.warning(
"No managed provider registered for %s — skipping cleanup",
cred.provider,
)
continue
try:
await provider.deprovision(user_id, cred)
logger.info(
"Deprovisioned managed credential for provider=%s user=%s",
cred.provider,
user_id,
)
except Exception:
logger.error(
"Failed to deprovision %s for user %s",
cred.provider,
user_id,
exc_info=True,
)

View File

@@ -0,0 +1,17 @@
"""Managed credential providers.
Call :func:`register_all` at application startup (e.g. in ``rest_api.py``)
to populate the provider registry before any requests are processed.
"""
from backend.integrations.managed_credentials import (
get_managed_provider,
register_managed_provider,
)
from backend.integrations.managed_providers.agentmail import AgentMailManagedProvider
def register_all() -> None:
"""Register every built-in managed credential provider (idempotent)."""
if get_managed_provider(AgentMailManagedProvider.provider_name) is None:
register_managed_provider(AgentMailManagedProvider())

View File

@@ -0,0 +1,90 @@
"""AgentMail managed credential provider.
Uses the org-level AgentMail API key to create a per-user pod and a
pod-scoped API key. The pod key is stored as an ``is_managed``
credential so it appears automatically in block credential dropdowns.
"""
from __future__ import annotations
import logging
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.managed_credentials import ManagedCredentialProvider
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
class AgentMailManagedProvider(ManagedCredentialProvider):
provider_name = "agent_mail"
async def is_available(self) -> bool:
return bool(settings.secrets.agentmail_api_key)
async def provision(self, user_id: str) -> Credentials:
from agentmail import AsyncAgentMail
client = AsyncAgentMail(api_key=settings.secrets.agentmail_api_key)
# client_id makes pod creation idempotent — if a pod already exists
# for this user_id the SDK returns the existing pod.
pod = await client.pods.create(client_id=user_id, name=f"{user_id}-pod")
# NOTE: api_keys.create() is NOT idempotent. If the caller retries
# after a partial failure (pod created, key created, but store write
# failed), a second key will be created and the first becomes orphaned
# on AgentMail's side. The double-check pattern in _ensure_one
# (has_managed_credential under lock) prevents this in normal flow;
# only a crash between key creation and store write can cause it.
api_key_obj = await client.pods.api_keys.create(
pod_id=pod.pod_id, name=f"{user_id}-agpt-managed"
)
return APIKeyCredentials(
provider=self.provider_name,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr(api_key_obj.api_key),
expires_at=None,
is_managed=True,
metadata={"pod_id": pod.pod_id},
)
async def deprovision(self, user_id: str, credential: Credentials) -> None:
from agentmail import AsyncAgentMail
pod_id = credential.metadata.get("pod_id")
if not pod_id:
logger.warning(
"Managed credential for user %s has no pod_id in metadata — "
"skipping AgentMail cleanup",
user_id,
)
return
client = AsyncAgentMail(api_key=settings.secrets.agentmail_api_key)
try:
# Verify the pod actually belongs to this user before deleting,
# as a safety measure against cross-user deletion via the
# org-level API key.
pod = await client.pods.get(pod_id=pod_id)
if getattr(pod, "client_id", None) and pod.client_id != user_id:
logger.error(
"Pod %s client_id=%s does not match user %s"
"refusing to delete",
pod_id,
pod.client_id,
user_id,
)
return
await client.pods.delete(pod_id=pod_id)
except Exception:
logger.warning(
"Failed to delete AgentMail pod %s for user %s",
pod_id,
user_id,
exc_info=True,
)

Some files were not shown because too many files have changed in this diff Show More