Compare commits

..

34 Commits

Author SHA1 Message Date
Bentlybro
957ec038b8 chore: regenerate OpenAPI schema for new LLM endpoints 2026-03-16 15:44:53 +00:00
Bentlybro
9b39a662ee feat(platform): Add LLM registry public read API
Implements public GET endpoints for querying LLM models and providers - Part 3 of 6 in the incremental registry rollout.

**Endpoints:**
- GET /api/llm/models - List all models (filterable by enabled_only)
- GET /api/llm/providers - List providers with their models

**Design:**
- Uses in-memory registry from PR 2 (no DB queries)
- Fast reads from cache populated at startup
- Grouped by provider for easy UI rendering

**Response models:**
- LlmModel - model info with capabilities, costs, creator
- LlmProvider - provider with nested models
- LlmModelsResponse - list + total count
- LlmProvidersResponse - grouped by provider

**Authentication:**
- Requires user auth (requires_user dependency)
- Public within authenticated sessions

**Integration:**
- Registered in rest_api.py at /api prefix
- Tagged with v2 + llm for OpenAPI grouping

**What's NOT included (later PRs):**
- Admin write API (PR 4)
- Block integration (PR 5)
- Redis cache (PR 6)

Lines: ~180 total
Files: 4 (3 new, 1 modified)
Review time: < 10 minutes
2026-03-16 15:44:53 +00:00
Bentlybro
9b93a956b4 style: fix trailing whitespace in registry.py 2026-03-16 15:44:45 +00:00
Bentlybro
b236719bbf fix(startup): handle missing AgentNode table in migrate_llm_models
Tests fail with 'relation "platform.AgentNode" does not exist' because
migrate_llm_models() runs during startup and queries a table that doesn't
exist in fresh test databases.

This is an existing bug in the codebase - the function has no error handling.

Wrap the call in try/except to gracefully handle test environments where
the AgentNode table hasn't been created yet.
2026-03-16 15:12:25 +00:00
Bentlybro
4f286f510f refactor: address CodeRabbit/Majdyz review feedback
- Fix ModelMetadata duplicate type collision by importing from blocks.llm
- Remove _json_to_dict helper, use dict() inline
- Add warning when Provider relation is missing (data corruption indicator)
- Optimize get_default_model_slug with next() (single sort pass)
- Optimize _build_schema_options to use list comprehension
- Move llm_registry import to top-level in rest_api.py
- Ensure max_output_tokens falls back to context_window when null

All critical and quick-win issues addressed.
2026-03-16 14:55:39 +00:00
Bentlybro
b1595d871d fix: address Sentry/CodeRabbit critical and major issues
**CRITICAL FIX - ModelMetadata instantiation:**
- Removed non-existent 'supports_vision' argument
- Added required fields: display_name, provider_name, creator_name, price_tier
- Handle nullable DB fields (Creator, priceTier, maxOutputTokens) safely
- Fallback: creator_name='Unknown' if no Creator, price_tier=1 if invalid

**MAJOR FIX - Preserve pricing unit:**
- Added 'unit' field to RegistryModelCost dataclass
- Prevents RUN vs TOKENS ambiguity in cached costs
- Convert Prisma enum to string when building cost objects

**MAJOR FIX - Deterministic default model:**
- Sort recommended models by display_name before selection
- Prevents non-deterministic results when multiple models are recommended
- Ensures consistent default across refreshes

**STARTUP IMPROVEMENT:**
- Added comment: graceful fallback OK for now (no blocks use registry yet)
- Will be stricter in PR #5 when block integration lands
- Added success log message for registry refresh

Fixes identified by Sentry (critical TypeError) and CodeRabbit review.
2026-03-16 14:55:39 +00:00
Bentlybro
29ab7f2d9c feat(platform): Add LLM registry core - DB layer + in-memory cache
Implements the registry core for dynamic LLM model management:

**DB Layer:**
- Fetch models with provider, costs, and creator relations
- Prisma query with includes for related data
- Convert DB records to typed dataclasses

**In-memory Cache:**
- Global dict for fast model lookups
- Atomic cache refresh with lock protection
- Schema options generation for UI dropdowns

**Public API:**
- get_model(slug) - lookup by slug
- get_all_models() - all models (including disabled)
- get_enabled_models() - enabled models only
- get_schema_options() - UI dropdown data
- get_default_model_slug() - recommended or first enabled
- refresh_llm_registry() - manual refresh trigger

**Integration:**
- Refresh at API startup (before block init)
- Graceful fallback if registry unavailable
- Enables blocks to consume registry data

**Models:**
- RegistryModel - full model with metadata
- RegistryModelCost - pricing configuration
- RegistryModelCreator - model creator info
- ModelMetadata - context window, capabilities

**Next PRs:**
- PR #3: Public read API (GET endpoints)
- PR #4: Admin write API (POST/PATCH/DELETE)
- PR #5: Block integration (update LLM block)
- PR #6: Redis cache (solve thundering herd)

Lines: ~230 (registry.py ~210, __init__.py ~30, model.py from draft)
Files: 4 (3 new, 1 modified)
2026-03-16 14:55:39 +00:00
Bentlybro
784936b323 revert: undo changes to graph.py
Reverting migrate_llm_models modifications per request.
Back to dev baseline for this file.
2026-03-16 14:55:39 +00:00
Bentlybro
f2ae38a1a7 fix(schema): address Majdyz review feedback
- Add FK constraints on LlmModelMigration (sourceModelSlug, targetModelSlug → LlmModel.slug)
- Remove unused @@index([credentialProvider]) on LlmModelCost
- Remove redundant @@index([isReverted]) on LlmModelMigration (covered by composite)
- Add documentation for credentialProvider field explaining its purpose
- Add reverse relation fields to LlmModel (SourceMigrations, TargetMigrations)

Fixes data integrity: typos in migration slugs now caught at DB level.
2026-03-16 14:52:19 +00:00
Bently
2ccfb4e4c1 Merge branch 'dev' into feat/llm-registry-schema 2026-03-10 17:52:01 +00:00
Otto
5641cdd3ca fix(backend): update test patches for validate_url → validate_url_host rename (#12358)
bfb843a renamed `validate_url` to `validate_url_host` in
`agent_browser`, `run_mcp_tool`, and MCP routes, but the corresponding
test files still patched the old name, causing `AttributeError` in CI.

Updates all mock patch targets and assertions across 3 test files:
- `agent_browser_test.py`
- `test_run_mcp_tool.py`  
- `mcp/test_routes.py`

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-10 17:22:11 +00:00
Bentlybro
c65e5c957a fix: isort import order 2026-03-10 16:43:34 +00:00
Bentlybro
54355a691b fix: use execute_raw_with_schema for proper multi-schema support
Per Sentry feedback: db.execute_raw ignores connection string's ?schema=
parameter and defaults to 'public' schema. This breaks in multi-schema setups.

Changes:
- Import execute_raw_with_schema from .db
- Use {schema_prefix} placeholder in query
- Call execute_raw_with_schema instead of db.execute_raw

This matches the pattern used in fix_llm_provider_credentials and other
schema-aware migrations. Works in both CI (public schema) and local
(platform schema from connection string).
2026-03-10 16:25:12 +00:00
Bentlybro
3cafa49c4c fix: remove hardcoded schema prefix from migrate_llm_models query
The raw SQL query in migrate_llm_models() hardcoded platform."AgentNode"
which fails in CI where tables are in 'public' schema (not 'platform').

This code exists in dev but only runs when LLM registry has data. With our
new schema, the migration tries to run at startup and fails in CI.

Changed: UPDATE platform."AgentNode" -> UPDATE "AgentNode"

Matches pattern of all other migrations - let connection string's default
schema handle routing.
2026-03-10 16:19:57 +00:00
Bentlybro
ded002a406 fix: remove CREATE SCHEMA to match CI environment
CI uses schema "public" as default (not "platform"), so creating
a platform schema then tables without prefix puts tables in public
but Prisma looks in platform.

Existing migrations don't create schema - they rely on connection
string's default. Remove CREATE SCHEMA IF NOT EXISTS to match.
2026-03-10 15:57:26 +00:00
Bentlybro
4fdf89c3be fix: remove schema prefix from migration SQL to match existing pattern
CI failing with 'relation "platform.AgentNode" does not exist' because
Prisma generates queries differently when tables are created with
explicit schema prefixes.

Existing AutoGPT migrations use:
  CREATE TABLE "AgentNode" (...)

Not:
  CREATE TABLE "platform"."AgentNode" (...)

The connection string's ?schema=platform handles schema selection,
so explicit prefixes aren't needed and cause compatibility issues.

Changes:
- Remove all "platform". prefixes from:
  * CREATE TYPE statements
  * CREATE TABLE statements
  * CREATE INDEX statements
  * ALTER TABLE statements
  * REFERENCES clauses in foreign keys

Now matches existing migration pattern exactly.
2026-03-10 15:38:41 +00:00
Bentlybro
d816bd739f fix: add partial unique indexes for data integrity
Per CodeRabbit feedback - fix 2 actual bugs:

1. Prevent multiple active migrations per source model
   - Add partial unique index: UNIQUE (sourceModelSlug) WHERE isReverted = false
   - Prevents ambiguous routing when resolving migrations

2. Allow both default and credential-specific costs
   - Remove @@unique([llmModelId, credentialProvider, unit])
   - Add 2 partial unique indexes:
     * UNIQUE (llmModelId, provider, unit) WHERE credentialId IS NULL (defaults)
     * UNIQUE (llmModelId, provider, credentialId, unit) WHERE credentialId IS NOT NULL (overrides)
   - Enables provider-level default costs + per-credential overrides

Schema comments document that these constraints exist in migration SQL.
2026-03-10 15:08:44 +00:00
Otto
bfb843a56e Merge commit from fork
* Fix SSRF via user-controlled ollama_host field

Validate ollama_host against BLOCKED_IP_NETWORKS before passing to
ollama.AsyncClient(). The server-configured default (env: OLLAMA_HOST)
is allowed without validation; user-supplied values that differ are
checked for private/internal IP resolution.

Fixes GHSA-6jx2-4h7q-3fx3

* Generalize validate_ollama_host to validate_host; fix description line length

* Rename to validate_untrusted_host with whitelist parameter

* Apply PR suggestion: include whitelist in error message; run formatting

* Move whitelist check after URL normalization; match on netloc

* revert unrelated formatting changes

* Dedup validate_url and validate_untrusted_host; normalize whitelist

* Move _resolve_and_check_blocked after calling functions

* dedup and clean up

* make trusted_hostnames truly optional

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-03-10 15:51:58 +01:00
Bentlybro
6a16376323 fix: remove multiSchema - follow existing AutoGPT pattern
Remove unnecessary multiSchema configuration that broke existing models.

AutoGPT uses connection string's ?schema=platform parameter as default,
not Prisma's multiSchema feature. Existing models (User, AgentGraph, etc.)
have no @@schema() directives and work fine.

Changes:
- Remove schemas = ["platform", "public"] from datasource
- Remove "multiSchema" from previewFeatures
- Remove all @@schema() directives from LLM models and enum

Migration SQL already creates tables in platform schema explicitly
(CREATE TABLE "platform"."LlmProvider" etc.) which is correct.

This matches the existing pattern used throughout the codebase.
2026-03-10 14:49:23 +00:00
Bentlybro
ed7b02ffb1 fix: address CodeRabbit design feedback
Per CodeRabbit review:

1. **Safety: Change capability defaults false → safer for partial seeding**
   - supportsTools: true → false
   - supportsJsonOutput: true → false
   - Prevents partially-seeded rows from being assumed capable

2. **Clarity: Rename supportsParallelTool → supportsParallelToolCalls**
   - More explicit about what the field represents

3. **Performance: Remove redundant indexes**
   - Drop @@index([llmModelId]) - covered by unique constraint
   - Drop @@index([sourceModelSlug]) - covered by composite index
   - Reduces write overhead and storage

4. **Documentation: Acknowledge customCreditCost limitation**
   - It's unit-agnostic (doesn't distinguish RUN vs TOKENS)
   - Noted as TODO for follow-up PR with proper unit-aware override

Schema + migration both updated to match.
2026-03-10 14:27:42 +00:00
Bentlybro
d064198dd1 fix: add @@schema("platform") to LlmCostUnit enum
Sentry caught this - enums also need @@schema directive with multiSchema enabled.
Without it, Prisma looks for enum in public schema but it's created in platform.
2026-03-10 14:23:24 +00:00
Bentlybro
01ad033b2b feat: add database CHECK constraints for data integrity
Per CodeRabbit feedback - enforce numeric domain rules at DB level:

Migration:
- priceTier: CHECK (priceTier BETWEEN 1 AND 3)
- creditCost: CHECK (creditCost >= 0)
- nodeCount: CHECK (nodeCount >= 0)
- customCreditCost: CHECK (customCreditCost IS NULL OR customCreditCost >= 0)

Schema comments:
- Document constraints inline for developer visibility

Prevents invalid data (negative costs, out-of-range tiers) from
entering the database, matching backend/blocks/llm.py contract.
2026-03-10 14:20:07 +00:00
Bentlybro
56bcbda054 fix: use @@schema() instead of @@map() for platform schema + create schema in migration
Critical fixes from PR review:

1. Replace @@map("platform.ModelName") with @@schema("platform")
   - Sentry correctly identified: Prisma was looking for literal table "platform.LlmProvider" with dot
   - Proper syntax: enable multiSchema feature + use @@schema directive

2. Create platform schema in migration
   - CI failed: schema "platform" does not exist
   - Add CREATE SCHEMA IF NOT EXISTS at start of migration

Schema changes:
- datasource: add schemas = ["platform", "public"]
- generator: add "multiSchema" to previewFeatures
- All 5 models: @@map() → @@schema("platform")

Migration changes:
- Add CREATE SCHEMA IF NOT EXISTS "platform" before enum creation

Fixes CI failure and Sentry-identified bug.
2026-03-10 14:15:35 +00:00
Abhimanyu Yadav
684845d946 fix(frontend/builder): handle discriminated unions and improve node layout (#12354)
## Summary
- **Discriminated union support (oneOf)**: Added a new `OneOfField`
component that properly
renders Pydantic discriminated unions. Hides the unusable parent object
handle, auto-populates
the discriminator value, shows a dropdown with variant titles (e.g.,
"Username" / "UserId"), and
filters out the internal discriminator field from the form.
Non-discriminated `oneOf` schemas
  fall back to existing `AnyOfField` behavior.
- **Collapsible object outputs**: Object-type outputs with nested keys
(e.g.,
`PersonLookupResponse.Url`, `PersonLookupResponse.profile`) are now
collapsed by default behind a
caret toggle. Nested keys show short names instead of the full
`Parent.Key` prefix.
- **Node layout cleanup**: Removed excessive bottom margin (`mb-6`) from
`FormRenderer`, hide the
Advanced toggle when no advanced fields exist, and add rounded bottom
corners on OUTPUT-type
  blocks.

<img width="440" height="427" alt="Screenshot 2026-03-10 at 11 31 55 AM"
src="https://github.com/user-attachments/assets/06cc5414-4e02-4371-bdeb-1695e7cb2c97"
/>
<img width="371" height="320" alt="Screenshot 2026-03-10 at 11 36 52 AM"
src="https://github.com/user-attachments/assets/1a55f87a-c602-4f4d-b91b-6e49f810e5d5"
/>

  ## Test plan
- [x] Add a Twitter Get User block — verify "Identifier" shows a
dropdown (Username/UserId) with
no unusable parent handle, discriminator field is hidden, and the block
can run without staying
  INCOMPLETE
- [x] Add any block with object outputs (e.g., PersonLookupResponse) —
verify nested keys are
  collapsed by default and expand on click with short labels
- [x] Verify blocks without advanced fields don't show the Advanced
toggle
- [x] Verify existing `anyOf` schemas (optional types, 3+ variant
unions) still render correctly
  - [x] Check OUTPUT-type blocks have rounded bottom corners

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: eureka928 <meobius123@gmail.com>
2026-03-10 14:13:32 +00:00
Bentlybro
d40efc6056 feat(platform): Add LLM registry database schema
Add Prisma schema and migration for dynamic LLM model registry:

Schema additions:
- LlmProvider: Registry of LLM providers (OpenAI, Anthropic, etc.)
- LlmModel: Individual models with capabilities and metadata
- LlmModelCost: Per-model pricing configuration
- LlmModelCreator: Model creators/trainers (OpenAI, Meta, etc.)
- LlmModelMigration: Track model migrations and reverts
- LlmCostUnit enum: RUN vs TOKENS pricing units

Key features:
- Model-specific capabilities (tools, JSON, reasoning, parallel calls)
- Flexible creator/provider separation (e.g., Meta model via Hugging Face)
- Migration tracking with custom pricing overrides
- Indexes for performance on common queries

Part 1 of incremental LLM registry implementation.
Refs: Draft PR #11699
2026-03-10 13:22:05 +00:00
Bently
6a6b23c2e1 fix(frontend): Remove unused Otto Server Action causing 107K+ errors (#12336)
## Summary

Fixes [OPEN-3025](https://linear.app/autogpt/issue/OPEN-3025) —
**107,571+ Server Action errors** in production

Removes the orphaned `askOtto` Server Action that was left behind after
the Otto chat widget removal in PR #12082.

## Problem

Next.js Server Actions that are never imported are excluded from the
server manifest. Old client bundles still reference the action ID,
causing "not found" errors.

**Sentry impact:**
- **BUILDER-3BN:** 107,571 events
- **BUILDER-729:** 285 events  
- **BUILDER-3QH:** 1,611 events
- **36+ users affected**

## Root Cause

1. **Mar 2025:** Otto widget added to `/build` page with `askOtto`
Server Action
2. **Feb 2026:** Otto widget removed (PR #12082), but `actions.ts` left
behind
3. **Result:** Dead code → not in manifest → errors

## Evidence

```bash
# Zero imports across frontend:
grep -r "askOtto" src/ --exclude="actions.ts"
# → No results

# Server manifest missing the action:
cat .next/server/server-reference-manifest.json
# → Only includes login/supabase actions, NOT build/actions
```

## Changes

-  Delete
`autogpt_platform/frontend/src/app/(platform)/build/actions.ts`

## Testing

1. Verify no imports of `askOtto` in codebase 
2. Check Sentry for error drop after deploy
3. Monitor for new "Server Action not found" errors

## Checklist

- [x] Dead code confirmed (zero imports)
- [x] Sentry issues documented
- [x] Clear commit message with context
2026-03-10 09:03:38 +00:00
Dream
d0a1d72e8a fix(frontend/builder): batch undo history for cascading operations (#12344)
## Summary

Fixes undo in the Builder not working correctly when deleting nodes.
When a node is deleted, React Flow fires `onNodesChange` (node removal)
and `onEdgesChange` (cascading edge cleanup) as separate callbacks —
each independently pushing to the undo history stack. This creates
intermediate states that break undo:

- Single undo restores a partial state (e.g. edges pointing to a deleted
node)
- Multiple undos required to fully restore the graph
- Redo also produces inconsistent states

Resolves #10999

### Changes 🏗️

- **`historyStore.ts`** — Added microtask-based batching to
`pushState()`. Multiple calls within the same synchronous execution
(same event loop tick) are coalesced into a single history entry,
keeping only the first pre-change snapshot. Uses `queueMicrotask` so all
cascading store updates from a single user action settle before the
history entry is committed.
- Reset `pendingState` in `initializeHistory()` and `clear()` to prevent
stale batched state from leaking across graph loads or navigation.

**Side benefit:** Copy/paste operations that add multiple nodes and
edges now also produce a single history entry instead of one per
node/edge.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Place 3 blocks A, B, C and connect A→B→C
  - [x] Delete block C (removes node + cascading edge B→C)
  - [x] Delete connection A→B
  - [x] Undo — connection A→B restored (single undo, not multiple)
  - [x] Undo — block C and connection B→C restored
  - [x] Redo — block C removed again with its connections
- [x] Copy/paste multiple connected blocks — single undo reverts entire
paste

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-03-10 04:55:07 +00:00
Zamil Majdy
f1945d6a2f feat(platform/copilot): @@agptfile: file-ref protocol for tool call inputs + block input toggle (#12332)
## Summary

- **Problem**: When the LLM calls a tool with large file content, it
must rewrite all content token-by-token. This is wasteful since the
files are already accessible on disk.
- **Solution**: Introduces an \`@@agptfile:\` reference protocol. The
LLM passes a file path reference; the processor loads and substitutes
the content before executing the tool.

### Protocol

\`\`\`
@@agptfile:<uri>[<start>-<end>]
\`\`\`

**Supported URI types:**
| URI | Source |
|-----|--------|
| \`workspace://<file_id>\` | Persistent workspace file by ID |
| \`workspace:///<path>\` | Workspace file by virtual path |
| \`/absolute/path\` | Absolute host or sandbox path |

**Line range** is optional; omitting it reads the whole file.

### Backend changes

- Rename \`@file:\` → \`@@agptfile:\` prefix for uniqueness; extract
\`FILE_REF_PREFIX\` constant
- Extract shared execution-context ContextVars into
\`backend/copilot/context.py\` — eliminates duplicate ContextVar objects
that caused \`e2b_file_tools.py\` to always see empty context
- \`tool_adapter.py\` imports ContextVars from \`context.py\` (single
source of truth)
- \`expand_file_refs_in_string\` raises \`FileRefExpansionError\` on
failure (instead of inline error strings), blocking tool execution and
returning a clear error hint to the model
- Tighten URI regex: only expand refs starting with \`workspace://\` or
\`/\`
- Aggregate budget: 1 MB total expansion cap across all refs in one
string
- Per-file cap: 200 KB per individual ref
- Fix \`_read_file_handler\` to pass \`get_sdk_cwd()\` to
\`is_allowed_local_path\` — ephemeral working directory files were
incorrectly blocked
- Fix \`_is_allowed_local\` in \`e2b_file_tools.py\` to pass
\`get_sdk_cwd()\`
- Restrict local path allow-list to \`tool-results/\` subdirectory only
(was entire session project dir)
- Add \`raise_on_error\` param + remove two-pass \`_FILE_REF_ERROR_RE\`
detection
- Update system prompt docs and tool_adapter error messages

### Frontend changes

- \`BlockInputCard\`: hidden by default with Show/Hide toggle + \`mb-2\`
spacing

## Test plan

- [ ] \`poetry run pytest backend/copilot/ -x
--ignore=backend/copilot/sdk/file_ref_integration_test.py\` passes
- [ ] \`@@agptfile:workspace:///<path>[1-50]\` expands correctly in tool
calls
- [ ] Invalid line ranges produce \`[file-ref error: ...]\` inline
messages
- [ ] Files outside \`sdk_cwd\` / \`tool-results/\` are rejected
- [ ] Block input card shows hidden by default with toggle
2026-03-09 18:39:13 +00:00
Zamil Majdy
6491cb1e23 feat(copilot): local agent generation with validation, fixing, MCP & sub-agent support (#12238)
## Summary

Port the agent generation pipeline from the external AgentGenerator
service into local copilot tools, making the Claude Agent SDK itself
handle validation, fixing, and block recommendation — no separate inner
LLM calls needed.

Key capabilities:
- **Local agent generation**: Create, edit, and customize agents
entirely within the SDK session
- **Graph validation**: 9 validation checks (block existence, link
references, type compatibility, IO blocks, etc.)
- **Graph fixing**: 17+ auto-fix methods (ID repair, link rewiring, type
conversion, credential stripping, dynamic block sink names, etc.)
- **MCP tool blocks**: Guide and fixer support for MCPToolBlock nodes
with proper dynamic input schema handling
- **Sub-agent composition**: AgentExecutorBlock support with library
agent schema enrichment
- **Embedding fallback**: Falls back to OpenRouter for embeddings when
`openai_internal_api_key` is unavailable
- **Actionable error messages**: Excluded block types (MCP, Agent)
return specific hints redirecting to the correct tool

### New Tools
- `validate_agent_graph` — run 9 validation checks on agent JSON
- `fix_agent_graph` — apply 17+ auto-fixes to agent JSON
- `get_blocks_for_goal` — recommend blocks for a given goal (with
optimized descriptions)

### Refactored Tools
- `create_agent`, `edit_agent`, `customize_agent` — accept `agent_json`
for local generation with shared fix→validate→save pipeline
- `find_block` — added `include_schemas` parameter, excludes MCP/Agent
blocks with actionable hints
- `run_block` — actionable error messages for excluded block types
- `find_library_agent` — enriched with `graph_version`, `input_schema`,
`output_schema` for sub-agent composition

### Architecture
- Split 2,558-line `validation.py` into `fixer.py`, `validator.py`,
`helpers.py`, `pipeline.py`
- Extracted shared `fix_validate_and_save()` pipeline (was duplicated
across 3 tools)
- Shared `OPENROUTER_BASE_URL` constant across codebase
- Comprehensive test coverage: 78+ unit tests for fixer/validator, 8
run_block tests, 17 SDK compat tests

## Test plan
- [x] `poetry run format` passes
- [x] `poetry run pytest -s -vvv backend/copilot/` — all tests pass
- [x] CI green on all Python versions (3.11, 3.12, 3.13)
- [x] Manual E2E: copilot generates agents with correct IO blocks,
links, and node structure
- [x] Manual E2E: MCP tool blocks use bare field names for dynamic
inputs
- [x] Manual E2E: sub-agent composition with AgentExecutorBlock
2026-03-09 16:10:22 +00:00
nKOxxx
c7124a5240 Add documentation for Google Gemini integration (#12283)
## Summary
Adding comprehensive documentation for Google Gemini integration with
AutoGPT.

## Changes
- Added setup instructions for Gemini API
- Documented configuration options
- Added examples and best practices

## Related Issues
N/A - Documentation improvement

## Testing
- Verified documentation accuracy
- Tested all code examples

## Checklist
- [x] Code follows project style
- [x] Documentation updated
- [x] Tests pass (if applicable)
2026-03-09 15:13:28 +00:00
Zamil Majdy
5537cb2858 dx: add shared Claude Code skills as auto-triggered guidelines (#12297)
## Summary
- Add 8 Claude Code skills under \`.claude/skills/\` that act as
**auto-triggered guidelines** — the LLM invokes them automatically based
on context, no manual \`/command\` needed
- Skills: \`pr-review\`, \`pr-create\`, \`new-block\`,
\`openapi-regen\`, \`backend-check\`, \`frontend-check\`,
\`worktree-setup\`, \`code-style\`
- Each skill has an explicit TRIGGER condition so the LLM knows when to
apply it without being asked

## Changes

### Skills (all auto-triggered by context)
| Skill | Trigger |
|-------|---------|
| \`pr-review\` | User shares a PR URL or asks to address review
comments |
| \`pr-create\` | User asks to create a PR, push changes for review, or
submit work |
| \`new-block\` | User asks to create a new block or add a new
integration |
| \`openapi-regen\` | API routes change, new endpoints added, or
frontend types are stale |
| \`backend-check\` | Backend Python code has been modified |
| \`frontend-check\` | Frontend TypeScript/React code has been modified
|
| \`worktree-setup\` | User asks to work on a branch in isolation or set
up a worktree |
| \`code-style\` | Writing or reviewing Python code |

## Test plan
- [ ] Verify skills appear automatically in Claude Code when context
matches (no \`/command\` needed)
- [ ] Modify frontend code — confirm \`frontend-check\` fires
automatically
- [ ] Ask Claude to "create a PR" — confirm \`pr-create\` fires without
\`/pr-create\`
- [ ] Share a PR URL — confirm \`pr-review\` fires automatically

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 15:10:38 +00:00
Zamil Majdy
aef5f6d666 feat(copilot): E2B sandbox auto-pause between turns to eliminate idle billing (#12330)
## Summary

### Before
- E2B sandboxes ran continuously between CoPilot turns, billing for idle
time
- Sandbox timeout caused **termination** (kill), losing all session
state
- No explicit cleanup when sessions were deleted — sandboxes leaked
- Single timeout concept with no separation between pause and kill
semantics

### After
- **Per-turn pause**: `pause_sandbox()` is called in the `finally` block
after every CoPilot turn, stopping billing instantly between turns
(paused sandboxes cost \$0 compute)
- **Auto-pause safety net**: Sandboxes are created with
`lifecycle={"on_timeout": "pause"}` (`pause_timeout` = 4h default) so
they auto-pause rather than terminate if the explicit pause is missed
- **Auto-reconnect**: `AsyncSandbox.connect()` in e2b SDK v2
auto-resumes paused sandboxes transparently — no extra code needed
- **Session delete cleanup**: `kill_sandbox()` is now called in
`delete_chat_session()` to explicitly terminate sandboxes and free
resources
- **Two distinct timeouts**: `pause_timeout` (4h, e2b auto-pause) vs
`redis_ttl` (12h, session key lifetime)

### Key Changes

| File | Change |
|------|--------|
| `pyproject.toml` | Bump `e2b-code-interpreter` `1.x` → `2.x` |
| `e2b_sandbox.py` | Add `pause_sandbox()`, `kill_sandbox()`,
`_act_on_sandbox()` helper; `lifecycle={"on_timeout": "pause"}`;
separate `pause_timeout` / `redis_ttl` params |
| `sdk/service.py` | Call `pause_sandbox()` in `finally` block
**before** transcript upload; use walrus operator for type-safe
`e2b_api_key` narrowing |
| `model.py` | Call `kill_sandbox()` in `delete_chat_session()`; inline
import to avoid circular dependency |
| `config.py` | Add `e2b_active` property; rename `e2b_sandbox_timeout`
default to 4h |
| `e2b_sandbox_test.py` | Add `test_pause_then_reconnect_reuses_sandbox`
test; update all `sandbox_timeout` → `pause_timeout` |

### Verified E2E
- Used real `E2B_API_KEY` from k8s dev cluster to manually verify:
sandbox created → paused → `is_running() == False` → reconnected via
`connect()` → state preserved → killed

## Test plan
- [x] `poetry run pytest backend/copilot/tools/e2b_sandbox_test.py` —
all 19 tests pass
- [x] CI: test (3.11, 3.12, 3.13), types — all green
- [x] E2E verified with real E2B credentials
2026-03-09 14:55:10 +00:00
Ubbe
8063391d0a feat(frontend/copilot): pin interactive tool cards outside reasoning collapse (#12346)
## Summary

<img width="400" height="227" alt="Screenshot 2026-03-09 at 22 43 10"
src="https://github.com/user-attachments/assets/0116e260-860d-4466-9763-e02de2766e50"
/>

<img width="600" height="618" alt="Screenshot 2026-03-09 at 22 43 14"
src="https://github.com/user-attachments/assets/beaa6aca-afa8-483f-ac06-439bf162c951"
/>

- When the copilot stream finishes, tool calls that require user
interaction (credentials, inputs, clarification) are now **pinned**
outside the "Show reasoning" collapse instead of being hidden
- Added `isInteractiveToolPart()` helper that checks tool output's
`type` field against a set of interactive response types
- Modified `splitReasoningAndResponse()` to extract interactive tools
from reasoning into the visible response section
- Added styleguide section with 3 demos: `setup_requirements`,
`agent_details`, and `agent_saved` pinning scenarios

### Interactive response types kept visible:
`setup_requirements`, `agent_details`, `block_details`, `need_login`,
`input_validation_error`, `clarification_needed`, `suggested_goal`,
`agent_preview`, `agent_saved`

Error responses remain in reasoning (LLM explains them in final text).

Closes SECRT-2088

## Test plan
- [ ] Verify copilot stream with interactive tool (e.g. run_agent
requiring credentials) keeps the tool card visible after stream ends
- [ ] Verify non-interactive tools (find_block, bash_exec) still
collapse into "Show reasoning"
- [ ] Verify styleguide page at `/copilot/styleguide` renders the new
"Reasoning Collapse: Interactive Tool Pinning" section correctly
- [ ] Verify `pnpm types`, `pnpm lint`, `pnpm format` all pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 23:12:14 +08:00
Otto
0bbb12d688 fix(frontend/copilot): hide New Chat button on Autopilot homepage (#12321)
Requested by @0ubbe

The **New Chat** button was visible on the Autopilot homepage where
clicking it has no effect (since `sessionId` is already `null`). This
hides the button when no chat session is active, so it only appears when
the user is viewing a conversation and wants to start a new one.

**Changes:**
- `ChatSidebar.tsx` — hide button in both collapsed and expanded sidebar
states when `sessionId` is null
- `MobileDrawer.tsx` — same fix for mobile drawer

---
Co-authored-by: Ubbe <ubbe@users.noreply.github.com>
2026-03-09 22:41:11 +08:00
140 changed files with 20165 additions and 8383 deletions

View File

@@ -0,0 +1,17 @@
---
name: backend-check
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Backend Check
## Steps
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
2. **Fix** any remaining errors manually, re-run until clean
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`

View File

@@ -0,0 +1,35 @@
---
name: code-style
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
user-invocable: false
metadata:
author: autogpt-team
version: "1.0.0"
---
# Code Style
## Imports
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
## Typing
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
## Code Structure
- **List comprehensions** over manual loop-and-append.
- **Early return** — guard clauses first, avoid deep nesting.
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
## Review Checklist
Before finishing, always ask:
- Can any function be split into smaller pieces?
- Is there unnecessary nesting that an early return would eliminate?
- Can any loop be a comprehension?
- Is there a simpler way to express this logic?

View File

@@ -0,0 +1,16 @@
---
name: frontend-check
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Frontend Check
## Steps (in order)
1. **Format**: `pnpm format` — NEVER run individual formatters
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user

View File

@@ -0,0 +1,29 @@
---
name: new-block
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# New Block Creation
Read `docs/platform/block-sdk-guide.md` first for the full guide.
## Steps
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
- `async def run` that `yield`s output fields
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
5. **Format**: `poetry run format`
## Rules
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
- Use top-level imports, avoid duck typing
- Always use `for_block_output` for block outputs

View File

@@ -0,0 +1,28 @@
---
name: openapi-regen
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# OpenAPI Spec Regeneration
## Steps
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID
pnpm types && pnpm lint && pnpm format
```
## Rules
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
- Don't manually edit files in `src/app/api/__generated__/`
- Generated hooks follow: `use{Method}{Version}{OperationName}`

View File

@@ -0,0 +1,31 @@
---
name: pr-create
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Create Pull Request
## Steps
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
5. **Fill out PR template** as the body — be thorough in the Changes section
6. **Format first** (if relevant changes exist):
- Backend: `cd autogpt_platform/backend && poetry run format`
- Frontend: `cd autogpt_platform/frontend && pnpm format`
- Fix any lint errors, then commit formatting changes before pushing
7. **Push**: `git push -u origin HEAD`
8. **Create PR**: `gh pr create --base dev`
9. **Output** the PR URL
## Rules
- Always target `dev` branch
- Do NOT run tests — CI will handle that
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`

View File

@@ -0,0 +1,51 @@
---
name: pr-review
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review Comment Workflow
## Steps
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
2. **Fetch comments** (all three sources):
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
3. **Skip** comments already reacted to by PR author
4. **For each unreacted comment**:
- Read referenced code, make the fix (or reply if you disagree/need info)
- **Inline review comments** (`pulls/{N}/comments`):
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
- **PR conversation comments** (`issues/{N}/comments`):
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
- No threaded replies — post a new issue comment if needed
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
5. **Include autogpt-reviewer bot fixes** too
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
7. **Commit & push**
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
9. **Stay productive while CI runs** — don't idle. In priority order:
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
- Address any remaining comments
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
10. **If CI fails** — fix, go back to step 6
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
12. **Done** only when: all comments reacted AND CI is green.
## CRITICAL: Do Not Stop
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
## Rules
- One todo per comment
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
- React to every comment: +1 addressed, -1 disagreed (with explanation)

View File

@@ -0,0 +1,45 @@
---
name: worktree-setup
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Worktree Setup
## Preferred: Use Branchlet
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
```bash
npm install -g branchlet # install once
branchlet create -n <name> -s <source-branch> -b <new-branch>
branchlet list --json # list all worktrees
```
## Manual Fallback
If branchlet isn't available:
1. `git worktree add ../<RepoName><N> <branch-name>`
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
3. Install deps:
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
- `cd autogpt_platform/frontend && pnpm install`
## Running the App
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
```bash
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd <worktree>/autogpt_platform/backend && poetry run app
```
## CoPilot Testing Gotcha
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

View File

@@ -1,8 +1,4 @@
from datetime import datetime
from typing import Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.util.models import Pagination
@@ -18,42 +14,3 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class CreateInvitedUserRequest(BaseModel):
email: EmailStr
name: Optional[str] = None
class InvitedUserResponse(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
class InvitedUsersResponse(BaseModel):
invited_users: list[InvitedUserResponse]
class BulkInvitedUserRowResponse(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserResponse] = None
class BulkInvitedUsersResponse(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResponse]

View File

@@ -1,143 +0,0 @@
import logging
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Security, UploadFile
from backend.data.invited_user import (
BulkInvitedUsersResult,
InvitedUserRecord,
bulk_create_invited_users_from_file,
create_invited_user,
list_invited_users,
retry_invited_user_tally,
revoke_invited_user,
)
from .model import (
BulkInvitedUserRowResponse,
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
InvitedUsersResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["users", "admin"],
dependencies=[Security(requires_admin_user)],
)
def _to_response(invited_user: InvitedUserRecord) -> InvitedUserResponse:
return InvitedUserResponse(**invited_user.model_dump())
def _to_bulk_response(result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return BulkInvitedUsersResponse(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
_to_response(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)
@router.get(
"/invited-users",
response_model=InvitedUsersResponse,
summary="List Invited Users",
)
async def get_invited_users(
admin_user_id: str = Security(get_user_id),
) -> InvitedUsersResponse:
logger.info("Admin user %s requested invited users", admin_user_id)
invited_users = await list_invited_users()
return InvitedUsersResponse(
invited_users=[_to_response(invited_user) for invited_user in invited_users]
)
@router.post(
"/invited-users",
response_model=InvitedUserResponse,
summary="Create Invited User",
)
async def create_invited_user_route(
request: CreateInvitedUserRequest,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s created invited user for %s",
admin_user_id,
request.email,
)
invited_user = await create_invited_user(request.email, request.name)
return _to_response(invited_user)
@router.post(
"/invited-users/bulk",
response_model=BulkInvitedUsersResponse,
summary="Bulk Create Invited Users",
operation_id="postV2BulkCreateInvitedUsers",
)
async def bulk_create_invited_users_route(
file: UploadFile = File(...),
admin_user_id: str = Security(get_user_id),
) -> BulkInvitedUsersResponse:
logger.info(
"Admin user %s bulk invited users from %s",
admin_user_id,
file.filename or "<unnamed>",
)
content = await file.read()
result = await bulk_create_invited_users_from_file(file.filename, content)
return _to_bulk_response(result)
@router.post(
"/invited-users/{invited_user_id}/revoke",
response_model=InvitedUserResponse,
summary="Revoke Invited User",
)
async def revoke_invited_user_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
invited_user = await revoke_invited_user(invited_user_id)
return _to_response(invited_user)
@router.post(
"/invited-users/{invited_user_id}/retry-tally",
response_model=InvitedUserResponse,
summary="Retry Invited User Tally",
)
async def retry_invited_user_tally_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s retried Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
invited_user = await retry_invited_user_tally(invited_user_id)
return _to_response(invited_user)

View File

@@ -1,165 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.invited_user import (
BulkInvitedUserRowResult,
BulkInvitedUsersResult,
InvitedUserRecord,
)
from .user_admin_routes import router as user_admin_router
app = fastapi.FastAPI()
app.include_router(user_admin_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _sample_invited_user() -> InvitedUserRecord:
now = datetime.now(timezone.utc)
return InvitedUserRecord(
id="invite-1",
email="invited@example.com",
status=prisma.enums.InvitedUserStatus.INVITED,
auth_user_id=None,
name="Invited User",
tally_understanding=None,
tally_status=prisma.enums.TallyComputationStatus.PENDING,
tally_computed_at=None,
tally_error=None,
created_at=now,
updated_at=now,
)
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
return BulkInvitedUsersResult(
created_count=1,
skipped_count=1,
error_count=0,
results=[
BulkInvitedUserRowResult(
row_number=1,
email="invited@example.com",
name=None,
status="CREATED",
message="Invite created",
invited_user=_sample_invited_user(),
),
BulkInvitedUserRowResult(
row_number=2,
email="duplicate@example.com",
name=None,
status="SKIPPED",
message="An invited user with this email already exists",
invited_user=None,
),
],
)
def test_get_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.list_invited_users",
AsyncMock(return_value=[_sample_invited_user()]),
)
response = client.get("/admin/invited-users")
assert response.status_code == 200
data = response.json()
assert len(data["invited_users"]) == 1
assert data["invited_users"][0]["email"] == "invited@example.com"
assert data["invited_users"][0]["status"] == "INVITED"
def test_create_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.create_invited_user",
AsyncMock(return_value=_sample_invited_user()),
)
response = client.post(
"/admin/invited-users",
json={"email": "invited@example.com", "name": "Invited User"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "invited@example.com"
assert data["name"] == "Invited User"
def test_bulk_create_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
AsyncMock(return_value=_sample_bulk_invited_users_result()),
)
response = client.post(
"/admin/invited-users/bulk",
files={
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
},
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 1
assert data["skipped_count"] == 1
assert data["results"][0]["status"] == "CREATED"
assert data["results"][1]["status"] == "SKIPPED"
def test_revoke_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
revoked = _sample_invited_user().model_copy(
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
AsyncMock(return_value=revoked),
)
response = client.post("/admin/invited-users/invite-1/revoke")
assert response.status_code == 200
assert response.json()["status"] == "REVOKED"
def test_retry_invited_user_tally(
mocker: pytest_mock.MockerFixture,
) -> None:
retried = _sample_invited_user().model_copy(
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
AsyncMock(return_value=retried),
)
response = client.post("/admin/invited-users/invite-1/retry-tally")
assert response.status_code == 200
assert response.json()["tally_status"] == "RUNNING"

View File

@@ -28,6 +28,7 @@ from backend.copilot.model import (
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -265,12 +266,12 @@ async def delete_session(
)
# Best-effort cleanup of the E2B sandbox (if any).
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
try:
await kill_sandbox(session_id, config.e2b_api_key)
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
await validate_url_host(auth_server_url)
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
await validate_url_host(registration_endpoint)
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
"""Verify that validate_url_host is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

View File

@@ -55,7 +55,6 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.invited_user import get_or_activate_user
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
update_user_onboarding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_activate_user(user_data)
user = await get_or_create_user(user_data)
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
# not produce a stored result before first activation.
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
@@ -163,11 +165,8 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
dependencies=[Security(requires_user)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
user_data: dict = Security(get_jwt_payload),
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
) -> dict[str, str]:
await get_or_activate_user(user_data)
await update_user_email(user_id, email)
return {"email": email}
@@ -183,7 +182,7 @@ async def get_user_timezone_route(
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_activate_user(user_data)
user = await get_or_create_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@@ -194,12 +193,9 @@ async def get_user_timezone_route(
dependencies=[Security(requires_user)],
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
user_data: dict = Security(get_jwt_payload),
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
await get_or_activate_user(user_data)
user = await update_user_timezone(user_id, str(request.timezone))
return TimezoneResponse(timezone=user.timezone)
@@ -212,9 +208,7 @@ async def update_user_timezone_route(
)
async def get_preferences(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
preferences = await get_user_notification_preference(user_id)
return preferences
@@ -228,9 +222,7 @@ async def get_preferences(
async def update_preferences(
user_id: Annotated[str, Security(get_user_id)],
preferences: NotificationPreferenceDTO = Body(...),
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
output = await update_user_notification_preference(user_id, preferences)
return output

View File

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
}
mocker.patch(
"backend.api.features.v1.get_or_activate_user",
"backend.api.features.v1.get_or_create_user",
return_value=mock_user,
)

View File

@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.user_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -38,8 +37,10 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.llm_registry
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
@@ -118,11 +119,30 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
# When block integration lands, this should fail hard or skip block initialization
try:
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry refreshed successfully at startup")
except Exception as e:
logger.warning(
f"Failed to refresh LLM registry at startup: {e}. "
"Blocks will initialize with empty registry."
)
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
try:
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
except Exception as e:
logger.warning(
f"Failed to migrate LLM models at startup: {e}. "
"This is expected in test environments without AgentNode table."
)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
@@ -312,11 +332,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.user_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/users",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],
@@ -354,6 +369,11 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.server.v2.llm.router,
tags=["v2", "llm"],
prefix="/api",
)
app.mount("/external-api", external_api)

View File

@@ -418,6 +418,8 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def __init__(
self,
id: str = "",
@@ -470,6 +472,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:

View File

@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code
execution = await sandbox.run_code(
execution = await sandbox.run_code( # type: ignore[attr-defined]
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url(input_data.url, [])
parsed_url, _, _ = await validate_url_host(input_data.url)
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -31,10 +31,14 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -804,6 +808,11 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -825,7 +834,7 @@ async def llm_call(
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -21,6 +21,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -136,7 +137,7 @@ class PerplexityBlock(Block):
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -1,10 +1,13 @@
"""Configuration management for chat system."""
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -19,7 +22,7 @@ class ChatConfig(BaseSettings):
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",
default=OPENROUTER_BASE_URL,
description="Base URL for API (e.g., for OpenRouter)",
)
@@ -112,9 +115,37 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",
)
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
default="pause",
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
)
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
Single source of truth for "should we use E2B right now?".
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
separately at call sites.
"""
return self.use_e2b_sandbox and bool(self.e2b_api_key)
@property
def active_e2b_api_key(self) -> str | None:
"""Return the E2B API key when E2B is enabled and configured, else None.
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
Use in callers::
if api_key := config.active_e2b_api_key:
# E2B is active; api_key is narrowed to str
"""
return self.e2b_api_key if self.e2b_active else None
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
@@ -164,7 +195,7 @@ class ChatConfig(BaseSettings):
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = "https://openrouter.ai/api/v1"
v = OPENROUTER_BASE_URL
return v
@field_validator("use_claude_agent_sdk", mode="before")

View File

@@ -0,0 +1,38 @@
"""Unit tests for ChatConfig."""
import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
_E2B_ENV_VARS = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
)
@pytest.fixture(autouse=True)
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _E2B_ENV_VARS:
monkeypatch.delenv(var, raising=False)
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""
def test_both_enabled_and_key_present_returns_true(self):
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
assert cfg.e2b_active is True
def test_enabled_but_missing_key_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
assert cfg.e2b_active is False
def test_disabled_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
assert cfg.e2b_active is False

View File

@@ -0,0 +1,115 @@
"""Shared execution context for copilot SDK tool handlers.
All context variables and their accessors live here so that
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
without creating circular dependencies.
"""
import os
import re
from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
# validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set per-turn context variables used by file-resolution tool handlers."""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Return the current (user_id, session) pair for the active request."""
return _current_user_id.get(), _current_session.get()
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current session, or None if not active."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK working directory for the current session (empty string if unset)."""
return _current_sdk_cwd.get()
E2B_WORKDIR = "/home/user"
def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
encoded = _current_project_dir.get("")
if encoded:
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -0,0 +1,163 @@
"""Tests for context.py — execution context variables and path helpers."""
from __future__ import annotations
import os
import tempfile
from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
set_execution_context,
)
def _make_session() -> MagicMock:
s = MagicMock()
s.session_id = "test-session"
return s
# ---------------------------------------------------------------------------
# Context variable getters
# ---------------------------------------------------------------------------
def test_get_execution_context_defaults():
"""get_execution_context returns (None, session) when user_id is not set."""
set_execution_context(None, _make_session())
user_id, session = get_execution_context()
assert user_id is None
assert session is not None
def test_set_and_get_execution_context():
"""set_execution_context stores user_id and session."""
mock_session = _make_session()
set_execution_context("user-abc", mock_session)
user_id, session = get_execution_context()
assert user_id == "user-abc"
assert session is mock_session
def test_get_current_sandbox_none_by_default():
"""get_current_sandbox returns None when no sandbox is set."""
set_execution_context("u1", _make_session(), sandbox=None)
assert get_current_sandbox() is None
def test_get_current_sandbox_returns_set_value():
"""get_current_sandbox returns the sandbox set via set_execution_context."""
mock_sandbox = MagicMock()
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
assert get_current_sandbox() is mock_sandbox
def test_get_sdk_cwd_empty_when_not_set():
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
set_execution_context("u1", _make_session(), sdk_cwd=None)
assert get_sdk_cwd() == ""
def test_get_sdk_cwd_returns_set_value():
"""get_sdk_cwd returns the value set via set_execution_context."""
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
assert get_sdk_cwd() == "/tmp/copilot-test"
# ---------------------------------------------------------------------------
# is_allowed_local_path
# ---------------------------------------------------------------------------
def test_is_allowed_local_path_empty():
assert not is_allowed_local_path("")
def test_is_allowed_local_path_inside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
path = os.path.join(cwd, "file.txt")
assert is_allowed_local_path(path, cwd)
def test_is_allowed_local_path_sdk_cwd_itself():
with tempfile.TemporaryDirectory() as cwd:
assert is_allowed_local_path(cwd, cwd)
def test_is_allowed_local_path_outside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
assert not is_allowed_local_path("/etc/passwd", cwd)
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
"""Without sdk_cwd or project_dir, all paths are rejected."""
_current_project_dir.set("")
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------
def test_resolve_sandbox_path_absolute_valid():
assert (
resolve_sandbox_path("/home/user/project/main.py")
== "/home/user/project/main.py"
)
def test_resolve_sandbox_path_relative():
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
def test_resolve_sandbox_path_workdir_itself():
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_resolve_sandbox_path_normalizes_dots():
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/etc/passwd")

View File

@@ -0,0 +1,138 @@
"""Scheduler job to generate LLM-optimized block descriptions.
Runs periodically to rewrite block descriptions into concise, actionable
summaries that help the copilot LLM pick the right blocks during agent
generation.
"""
import asyncio
import logging
from backend.blocks import get_blocks
from backend.util.clients import get_database_manager_client, get_openai_client
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
"You are a technical writer for an automation platform. "
"Rewrite the following block description to be concise (under 50 words), "
"informative, and actionable. Focus on what the block does and when to "
"use it. Output ONLY the rewritten description, nothing else. "
"Do not use markdown formatting."
)
# Rate-limit delay between sequential LLM calls (seconds)
_RATE_LIMIT_DELAY = 0.5
# Maximum tokens for optimized description generation
_MAX_DESCRIPTION_TOKENS = 150
# Model for generating optimized descriptions (fast, cheap)
_MODEL = "gpt-4o-mini"
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
"""Call the shared OpenAI client to rewrite each block description."""
client = get_openai_client()
if client is None:
logger.error(
"No OpenAI client configured, skipping block description optimization"
)
return {}
results: dict[str, str] = {}
for block in blocks:
block_id = block["id"]
block_name = block["name"]
description = block["description"]
try:
response = await client.chat.completions.create(
model=_MODEL,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Block name: {block_name}\nDescription: {description}",
},
],
max_tokens=_MAX_DESCRIPTION_TOKENS,
)
optimized = (response.choices[0].message.content or "").strip()
if optimized:
results[block_id] = optimized
logger.debug("Optimized description for %s", block_name)
else:
logger.warning("Empty response for block %s", block_name)
except Exception:
logger.warning(
"Failed to optimize description for %s", block_name, exc_info=True
)
await asyncio.sleep(_RATE_LIMIT_DELAY)
return results
def optimize_block_descriptions() -> dict[str, int]:
"""Generate optimized descriptions for blocks that don't have one yet.
Uses the shared OpenAI client to rewrite block descriptions into concise
summaries suitable for agent generation prompts.
Returns:
Dict with counts: processed, success, failed, skipped.
"""
db_client = get_database_manager_client()
blocks = db_client.get_blocks_needing_optimization()
if not blocks:
logger.info("All blocks already have optimized descriptions")
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
non_empty = [b for b in blocks if b.get("description", "").strip()]
skipped = len(blocks) - len(non_empty)
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
stats = {
"processed": len(non_empty),
"success": len(new_descriptions),
"failed": len(non_empty) - len(new_descriptions),
"skipped": skipped,
}
logger.info(
"Block description optimization complete: "
"%d/%d succeeded, %d failed, %d skipped",
stats["success"],
stats["processed"],
stats["failed"],
stats["skipped"],
)
if new_descriptions:
for block_id, optimized in new_descriptions.items():
db_client.update_block_optimized_description(block_id, optimized)
# Update in-memory descriptions first so the cache rebuilds with fresh data.
try:
block_classes = get_blocks()
for block_id, optimized in new_descriptions.items():
if block_id in block_classes:
block_classes[block_id]._optimized_description = optimized
logger.info(
"Updated %d in-memory block descriptions", len(new_descriptions)
)
except Exception:
logger.warning(
"Could not update in-memory block descriptions", exc_info=True
)
from backend.copilot.tools.agent_generator.blocks import (
reset_block_caches, # local to avoid circular import
)
reset_block_caches()
return stats

View File

@@ -0,0 +1,91 @@
"""Unit tests for optimize_blocks._optimize_descriptions."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
def _make_client_response(text: str) -> MagicMock:
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
choice = MagicMock()
choice.message.content = text
response = MagicMock()
response.choices = [choice]
return response
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
class TestOptimizeDescriptions:
"""Tests for _optimize_descriptions async function."""
def test_returns_empty_when_no_client(self):
with patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
):
result = _run(
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
)
assert result == {}
def test_success_single_block(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("Short desc.")
)
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {"b1": "Short desc."}
client.chat.completions.create.assert_called_once()
def test_skips_block_on_exception(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {}
def test_sleeps_between_blocks(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("desc")
)
blocks = [
{"id": "b1", "name": "B1", "description": "d1"},
{"id": "b2", "name": "B2", "description": "d2"},
]
sleep_mock = AsyncMock()
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
):
_run(_optimize_descriptions(blocks))
assert sleep_mock.call_count == 2
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)

View File

@@ -26,6 +26,33 @@ your message as a Markdown link or image:
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.
Syntax: `@@agptfile:<uri>[<start>-<end>]`
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
- `workspace://<file_id>` — workspace file by ID
- `workspace:///<path>` — workspace file by virtual path
- `/absolute/local/path` — ephemeral or sdk_cwd file
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
Examples:
```
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.py
```
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.

View File

@@ -0,0 +1,155 @@
## Agent Generation Guide
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
### Agent JSON Structure
```json
{
"id": "<UUID v4>", // auto-generated if omitted
"version": 1,
"is_active": true,
"name": "Agent Name",
"description": "What the agent does",
"nodes": [
{
"id": "<UUID v4>",
"block_id": "<block UUID from find_block>",
"input_default": {
"field_name": "design-time value"
},
"metadata": {
"position": {"x": 0, "y": 0},
"customized_name": "Optional display name"
}
}
],
"links": [
{
"id": "<UUID v4>",
"source_id": "<source node UUID>",
"source_name": "output_field_name",
"sink_id": "<sink node UUID>",
"sink_name": "input_field_name",
"is_static": false
}
]
}
```
### REQUIRED: AgentInputBlock and AgentOutputBlock
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
- Required `input_default` fields: `name` (str)
- The `value` input should be linked from another block's output
- Optional: `title`, `description`, `format` (Jinja2 template)
- Create one AgentOutputBlock per distinct result to show the user
Without these blocks, the agent has no interface and the user cannot provide
inputs or see outputs. NEVER skip them.
### Key Rules
- **Name & description**: Include `name` and `description` in the agent JSON
when creating a new agent, or when editing and the agent's purpose changed.
Without these the agent gets a generic default name.
- **Design-time vs runtime**: `input_default` = values known at build time.
For user-provided values, create an `AgentInputBlock` node and link its
output to the consuming block's input.
- **Credentials**: Do NOT require credentials upfront. Users configure
credentials later in the platform UI after the agent is saved.
- **Node spacing**: Position nodes with at least 800 X-units between them.
- **Nested properties**: Use `parentField_#_childField` notation in link
sink_name/source_name to access nested object fields.
- **is_static links**: Set `is_static: true` when the link carries a
design-time constant (matches a field in inputSchema with a default).
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
literal braces in prompt strings — single `{` and `}` are for
template variables.
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
`graph_version` in input_default, and wire inputs/outputs to match
the sub-agent's schema.
### Using Sub-Agents (AgentExecutorBlock)
To compose agents using other agents as sub-agents:
1. Call `find_library_agent` to find the sub-agent — the response includes
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
3. Set `input_default`:
- `graph_id`: from the library agent's `graph_id`
- `graph_version`: from the library agent's `graph_version`
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
- `user_id`: leave as `""` (filled at runtime)
- `inputs`: `{}` (populated by links at runtime)
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
property names (e.g., if input_schema has a `"url"` property, use
`"url"` as the sink_name)
5. Wire outputs: link from source names matching the sub-agent's
`output_schema` property names
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
the library agent IDs used, so the fixer can validate schemas
### Using MCP Tools (MCPToolBlock)
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
3. Set `input_default`:
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
- `selected_tool`: the tool name on that server
- `tool_input_schema`: JSON Schema for the tool's inputs
- `tool_arguments`: `{}` (populated by links or hardcoded values)
4. The block requires MCP credentials — the user configures these in the
platform UI after the agent is saved
5. Wire inputs using the tool argument field name directly as the sink_name
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
automatically collects top-level fields matching tool_input_schema into
tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
input_default: {"name": "user_text", "title": "Text to process"},
output: "result")
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
input_default: {"name": "summary", "title": "Summary"},
input: "value" linked from Node 2's output)

View File

@@ -8,8 +8,6 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
@@ -17,36 +15,23 @@ import os
import shlex
from typing import Any, Callable
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
logger = logging.getLogger(__name__)
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox(): # type: ignore[return]
from .tool_adapter import get_current_sandbox # noqa: E402
def _get_sandbox():
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
from .tool_adapter import is_allowed_local_path # noqa: E402
return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
return is_allowed_local_path(path, get_sdk_cwd())
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
@@ -63,7 +48,7 @@ def _get_sandbox_and_path(
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = _resolve_remote(file_path)
remote = resolve_sandbox_path(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
@@ -73,6 +58,7 @@ def _get_sandbox_and_path(
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
@@ -104,6 +90,7 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
"""Write content to a sandbox file, creating parent directories as needed."""
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
@@ -127,6 +114,7 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
"""Replace a substring in a sandbox file, with optional replace-all support."""
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
@@ -172,6 +160,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
"""Find files matching a name pattern inside the sandbox using ``find``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
@@ -183,7 +172,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -198,6 +187,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
@@ -210,7 +200,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -238,7 +228,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded) as fh:
with open(expanded, encoding="utf-8", errors="replace") as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)

View File

@@ -7,59 +7,60 @@ import os
import pytest
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# _resolve_remote — sandbox path normalisation & boundary enforcement
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveRemote:
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert _resolve_remote("/home/user") == "/home/user"
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert _resolve_remote("./README.md") == "/home/user/README.md"
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("../../etc/passwd")
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/user/../../etc/passwd")
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/etc/passwd")
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/")
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/other/file.txt")
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert _resolve_remote("src/") == "/home/user/src"
assert resolve_sandbox_path("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,281 @@
"""File reference protocol for tool call inputs.
Allows the LLM to pass a file reference instead of embedding large content
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
arguments before the tool is executed.
Protocol
--------
@@agptfile:<uri>[<start>-<end>]
``<uri>`` (required)
- ``workspace://<file_id>`` — workspace file by ID
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
- ``workspace:///<path>`` — workspace file by virtual path
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
- Any absolute path that resolves inside the E2B sandbox
(``/home/user/...``) when a sandbox is active
``[<start>-<end>]`` (optional)
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
Omit to read the entire file.
Examples
--------
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.sh
"""
import itertools
import logging
import os
import re
from dataclasses import dataclass
from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.workspace_files import get_manager
from backend.util.file import parse_workspace_uri
class FileRefExpansionError(Exception):
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
Separating this from inline substitution lets callers (e.g. the MCP tool
wrapper) block tool execution and surface a helpful error to the model
rather than passing an ``[file-ref error: …]`` string as actual input.
"""
logger = logging.getLogger(__name__)
FILE_REF_PREFIX = "@@agptfile:"
# Matches: @@agptfile:<uri>[start-end]?
# Group 1 URI; must start with '/' (absolute path) or 'workspace://'
# Group 2 start line (optional)
# Group 3 end line (optional)
_FILE_REF_RE = re.compile(
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
)
# Maximum characters returned for a single file reference expansion.
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
@dataclass
class FileRef:
uri: str
start_line: int | None # 1-indexed, inclusive
end_line: int | None # 1-indexed, inclusive
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
expand references embedded in larger strings.
"""
m = _FILE_REF_RE.fullmatch(text.strip())
if not m:
return None
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if start is not None and start < 1:
return None
if end is not None and end < 1:
return None
if start is not None and end is not None and end < start:
return None
return FileRef(uri=m.group(1), start_line=start, end_line=end)
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
async def read_file_bytes(
uri: str,
user_id: str | None,
session: ChatSession,
) -> bytes:
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
Raises :class:`ValueError` if the URI cannot be resolved.
"""
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
return await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
with open(resolved, "rb") as fh:
return fh.read()
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
if sandbox is not None:
try:
remote = resolve_sandbox_path(plain)
except ValueError as exc:
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
return bytes(await sandbox.files.read(remote, format="bytes"))
except Exception as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
)
async def resolve_file_ref(
ref: FileRef,
user_id: str | None,
session: ChatSession,
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
*,
raise_on_error: bool = False,
) -> str:
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
Non-reference text is passed through unchanged.
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
where partial expansion is acceptable.
If *raise_on_error* is ``True``, any resolution failure raises
:class:`FileRefExpansionError` immediately so the caller can block the
operation and surface a clean error to the model.
"""
if FILE_REF_PREFIX not in text:
return text
result: list[str] = []
last_end = 0
total_chars = 0
for m in _FILE_REF_RE.finditer(text):
result.append(text[last_end : m.start()])
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if (start is not None and start < 1) or (end is not None and end < 1):
msg = f"line numbers must be >= 1: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
if start is not None and end is not None and end < start:
msg = f"end line must be >= start line: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
try:
content = await resolve_file_ref(ref, user_id, session)
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
content = content[:remaining] + "\n... [total budget exhausted]"
total_chars += len(content)
result.append(content)
except ValueError as exc:
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
if raise_on_error:
raise FileRefExpansionError(str(exc)) from exc
result.append(f"[file-ref error: {exc}]")
last_end = m.end()
result.append(text[last_end:])
return "".join(result)
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
response that lets the model correct the reference before retrying.
"""
if not args:
return args
async def _expand(value: Any) -> Any:
if isinstance(value, str):
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
return {k: await _expand(v) for k, v in value.items()}
if isinstance(value, list):
return [await _expand(item) for item in value]
return value
return {k: await _expand(v) for k, v in args.items()}

View File

@@ -0,0 +1,328 @@
"""Integration tests for @@agptfile: reference expansion in tool calls.
These tests verify the end-to-end behaviour of the file reference protocol:
- Parsing @@agptfile: tokens from tool arguments
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
- The extended Read tool handler (workspace:// pass-through via session context)
No real LLM or database is required; workspace reads are stubbed where needed.
"""
from __future__ import annotations
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
FileRef,
expand_file_refs_in_args,
expand_file_refs_in_string,
read_file_bytes,
resolve_file_ref,
)
from backend.copilot.sdk.tool_adapter import _read_file_handler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "integ-sess") -> MagicMock:
s = MagicMock()
s.session_id = session_id
return s
# ---------------------------------------------------------------------------
# Local-file resolution (sdk_cwd)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path():
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
# Write a test file inside sdk_cwd
test_file = os.path.join(sdk_cwd, "hello.txt")
with open(test_file, "w") as f:
f.write("line1\nline2\nline3\n")
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=None, end_line=None)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line1\nline2\nline3\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path_with_line_range():
"""resolve_file_ref respects line ranges for local files."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "multi.txt")
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
with open(test_file, "w") as f:
f.writelines(lines)
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=3, end_line=5)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line3\nline4\nline5\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var:
mock_cwd_var.get.return_value = sdk_cwd
mock_sandbox_var.get.return_value = None
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
with pytest.raises(ValueError, match="not allowed"):
await resolve_file_ref(ref, user_id="u1", session=_make_session())
# ---------------------------------------------------------------------------
# expand_file_refs_in_string — integration with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_string_with_real_file():
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "data.txt")
with open(test_file, "w") as f:
f.write("hello world\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"Content: @@agptfile:{test_file}",
user_id="u1",
session=_make_session(),
)
assert result == "Content: hello world\n"
@pytest.mark.asyncio
async def test_expand_string_missing_file_is_surfaced_inline():
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
with tempfile.TemporaryDirectory() as sdk_cwd:
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"@@agptfile:{missing}",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "not found" in result.lower() or "not allowed" in result.lower()
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — dict traversal with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_replaces_file_ref_in_nested_dict():
"""Nested @@agptfile: references in args are fully expanded."""
with tempfile.TemporaryDirectory() as sdk_cwd:
file_a = os.path.join(sdk_cwd, "a.txt")
file_b = os.path.join(sdk_cwd, "b.txt")
with open(file_a, "w") as f:
f.write("AAA")
with open(file_b, "w") as f:
f.write("BBB")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{
"outer": {
"content_a": f"@@agptfile:{file_a}",
"content_b": f"start @@agptfile:{file_b} end",
},
"count": 42,
},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["content_a"] == "AAA"
assert result["outer"]["content_b"] == "start BBB end"
assert result["count"] == 42
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_handler_local_file():
"""_read_file_handler reads a local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "read_test.txt")
lines = [f"L{i}\n" for i in range(1, 6)]
with open(test_file, "w") as f:
f.writelines(lines)
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj_var, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd_var.get.return_value = sdk_cwd
mock_proj_var.get.return_value = ""
result = await _read_file_handler(
{"file_path": test_file, "offset": 0, "limit": 5}
)
assert not result["isError"]
text = result["content"][0]["text"]
assert "L1" in text
assert "L5" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri():
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
mock_session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
)
assert not result["isError"], result["content"][0]["text"]
text = result["content"][0]["text"]
assert "workspace file content" in text
assert "line two" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri_no_session():
"""_read_file_handler returns error when workspace:// is used without session."""
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=(None, None),
):
result = await _read_file_handler({"file_path": "workspace://some-id"})
assert result["isError"]
assert "session" in result["content"][0]["text"].lower()
@pytest.mark.asyncio
async def test_read_file_handler_access_denied():
"""_read_file_handler rejects paths outside allowed locations."""
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd.get.return_value = "/tmp/safe-dir"
mock_sandbox.get.return_value = None
result = await _read_file_handler({"file_path": "/etc/passwd"})
assert result["isError"]
assert "not allowed" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_bytes_workspace_virtual_path():
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
assert result == b"virtual path content"
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_sandbox_branch():
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
session = _make_session()
mock_sandbox = AsyncMock()
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
result = await read_file_bytes("/home/user/script.sh", None, session)
assert result == b"sandbox content"
mock_sandbox.files.read.assert_awaited_once_with(
"/home/user/script.sh", format="bytes"
)
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
session = _make_session()
mock_sandbox = AsyncMock()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
with pytest.raises(ValueError, match="not allowed"):
await read_file_bytes("/etc/passwd", None, session)

View File

@@ -0,0 +1,382 @@
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
_MAX_EXPAND_CHARS,
FileRef,
FileRefExpansionError,
_apply_line_range,
expand_file_refs_in_args,
expand_file_refs_in_string,
parse_file_ref,
)
# ---------------------------------------------------------------------------
# parse_file_ref
# ---------------------------------------------------------------------------
def test_parse_file_ref_workspace_id():
ref = parse_file_ref("@@agptfile:workspace://abc123")
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
def test_parse_file_ref_workspace_id_with_mime():
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
assert ref is not None
assert ref.uri == "workspace://abc123#text/plain"
assert ref.start_line is None
def test_parse_file_ref_workspace_path():
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
assert ref is not None
assert ref.uri == "workspace:///reports/q1.md"
def test_parse_file_ref_with_line_range():
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
def test_parse_file_ref_local_path():
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
assert ref is not None
assert ref.uri == "/tmp/copilot-session/output.py"
assert ref.start_line == 1
assert ref.end_line == 100
def test_parse_file_ref_no_match():
assert parse_file_ref("just a normal string") is None
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
assert (
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
) # not full match
def test_parse_file_ref_strips_whitespace():
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
assert ref is not None
assert ref.uri == "workspace://abc123"
def test_parse_file_ref_invalid_range_zero_start():
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
def test_parse_file_ref_invalid_range_end_less_than_start():
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
def test_parse_file_ref_invalid_range_zero_end():
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
# ---------------------------------------------------------------------------
# _apply_line_range
# ---------------------------------------------------------------------------
TEXT = "line1\nline2\nline3\nline4\nline5\n"
def test_apply_line_range_no_range():
assert _apply_line_range(TEXT, None, None) == TEXT
def test_apply_line_range_start_only():
result = _apply_line_range(TEXT, 3, None)
assert result == "line3\nline4\nline5\n"
def test_apply_line_range_full():
result = _apply_line_range(TEXT, 2, 4)
assert result == "line2\nline3\nline4\n"
def test_apply_line_range_single_line():
result = _apply_line_range(TEXT, 2, 2)
assert result == "line2\n"
def test_apply_line_range_beyond_eof():
result = _apply_line_range(TEXT, 4, 999)
assert result == "line4\nline5\n"
# ---------------------------------------------------------------------------
# expand_file_refs_in_string
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "sess-1") -> MagicMock:
session = MagicMock()
session.session_id = session_id
return session
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
"""Stub resolver that returns the URI and range as a descriptive string."""
if ref.start_line is not None:
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
return f"content:{ref.uri}"
@pytest.mark.asyncio
async def test_expand_no_refs():
result = await expand_file_refs_in_string(
"no references here", user_id="u1", session=_make_session()
)
assert result == "no references here"
@pytest.mark.asyncio
async def test_expand_single_ref():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123"
@pytest.mark.asyncio
async def test_expand_ref_with_range():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[10-50]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123[10-50]"
@pytest.mark.asyncio
async def test_expand_ref_embedded_in_text():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"Here is the file: @@agptfile:workspace://abc123 — done",
user_id="u1",
session=_make_session(),
)
assert result == "Here is the file: content:workspace://abc123 — done"
@pytest.mark.asyncio
async def test_expand_multiple_refs():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
@pytest.mark.asyncio
async def test_expand_invalid_range_zero_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[0-5]",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "line numbers must be >= 1" in result
@pytest.mark.asyncio
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
result = await expand_file_refs_in_string(
"prefix @@agptfile:workspace://abc123[10-5] suffix",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "end line must be >= start line" in result
assert "prefix" in result
assert "suffix" in result
@pytest.mark.asyncio
async def test_expand_ref_error_surfaces_inline():
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("file not found")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://bad",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "file not found" in result
# ---------------------------------------------------------------------------
# expand_file_refs_in_args
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_flat():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"content": "@@agptfile:workspace://abc123", "other": 42},
user_id="u1",
session=_make_session(),
)
assert result["content"] == "content:workspace://abc123"
assert result["other"] == 42
@pytest.mark.asyncio
async def test_expand_args_nested_dict():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"outer": {"inner": "@@agptfile:workspace://nested"}},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["inner"] == "content:workspace://nested"
@pytest.mark.asyncio
async def test_expand_args_list():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{
"items": [
"@@agptfile:workspace://a",
"plain",
"@@agptfile:workspace://b[1-3]",
]
},
user_id="u1",
session=_make_session(),
)
assert result["items"] == [
"content:workspace://a",
"plain",
"content:workspace://b[1-3]",
]
@pytest.mark.asyncio
async def test_expand_args_empty():
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
assert result == {}
@pytest.mark.asyncio
async def test_expand_args_no_refs():
result = await expand_file_refs_in_args(
{"key": "no refs here", "num": 1},
user_id="u1",
session=_make_session(),
)
assert result == {"key": "no refs here", "num": 1}
@pytest.mark.asyncio
async def test_expand_args_raises_on_file_ref_error():
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
the inline error string to the tool, blocking tool execution."""
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("path does not exist")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
with pytest.raises(FileRefExpansionError) as exc_info:
await expand_file_refs_in_args(
{"prompt": "@@agptfile:/home/user/missing.txt"},
user_id="u1",
session=_make_session(),
)
assert "path does not exist" in str(exc_info.value)
# ---------------------------------------------------------------------------
# Per-file truncation and aggregate budget
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_per_file_truncation():
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
return oversized
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_oversized),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://big-file",
user_id="u1",
session=_make_session(),
)
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
assert "[truncated]" in result
@pytest.mark.asyncio
async def test_expand_aggregate_budget_exhausted():
"""When the aggregate budget is exhausted, later refs get the budget message."""
# Each file returns just under 300K; after ~4 files the 1M budget is used.
big_chunk = "y" * 300_000
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
return big_chunk
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_big),
):
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
result = await expand_file_refs_in_string(
refs,
user_id="u1",
session=_make_session(),
)
assert "budget exhausted" in result

View File

@@ -0,0 +1,28 @@
## MCP Tool Guide
### Workflow
`run_mcp_tool` follows a two-step pattern:
1. **Discover** — call with only `server_url` to list available tools on the server.
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
### Known hosted MCP servers
Use these URLs directly without asking the user:
| Service | URL |
|---|---|
| Notion | `https://mcp.notion.com/mcp` |
| Linear | `https://mcp.linear.app/mcp` |
| Stripe | `https://mcp.stripe.com` |
| Intercom | `https://mcp.intercom.com/mcp` |
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
### Authentication
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.

View File

@@ -536,10 +536,12 @@ async def test_wait_for_stash_signaled():
result = await wait_for_stash(timeout=1.0)
assert result is True
assert _pto.get({}).get("WebSearch") == ["result data"]
pto = _pto.get()
assert pto is not None
assert pto.get("WebSearch") == ["result data"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -554,7 +556,7 @@ async def test_wait_for_stash_timeout():
assert result is False
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -573,10 +575,12 @@ async def test_wait_for_stash_already_stashed():
assert result is True
# But the stash itself is populated
assert _pto.get({}).get("Read") == ["file contents"]
pto = _pto.get()
assert pto is not None
assert pto.get("Read") == ["file contents"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)

View File

@@ -10,12 +10,13 @@ import re
from collections.abc import Callable
from typing import Any, cast
from backend.copilot.context import is_allowed_local_path
from .tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)

View File

@@ -9,8 +9,9 @@ import os
import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .service import _is_tool_error_or_denial
SDK_CWD = "/tmp/copilot-abc123"
@@ -120,8 +121,6 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
@@ -133,16 +132,14 @@ def test_read_tool_results_allowed():
_current_project_dir.reset(token)
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
def test_read_claude_projects_settings_json_denied():
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert not _is_denied(result)
assert _is_denied(result)
finally:
_current_project_dir.reset(token)
@@ -357,76 +354,3 @@ async def test_task_slot_released_on_failure(_hooks):
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
class TestIsToolErrorOrDenial:
def test_none_content(self):
assert _is_tool_error_or_denial(None) is False
def test_empty_content(self):
assert _is_tool_error_or_denial("") is False
def test_benign_output(self):
assert _is_tool_error_or_denial("All good, no issues.") is False
def test_security_marker(self):
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
def test_cannot_be_bypassed(self):
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
def test_not_allowed(self):
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
def test_background_task_denial(self):
assert (
_is_tool_error_or_denial(
"Background task execution is not supported. "
"Run tasks in the foreground instead."
)
is True
)
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)
def test_denied_marker(self):
assert (
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
)
def test_blocked_marker(self):
assert _is_tool_error_or_denial("Request blocked by security policy") is True
def test_failed_marker(self):
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
def test_mcp_iserror(self):
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
def test_benign_error_in_value(self):
"""Content like '0 errors found' should not trigger — 'error' was removed."""
assert _is_tool_error_or_denial("0 errors found") is False
def test_benign_permission_field(self):
"""Schema descriptions mentioning 'permission' should not trigger."""
assert (
_is_tool_error_or_denial(
'{"fields": [{"name": "permission_level", "type": "int"}]}'
)
is False
)
def test_benign_not_found_in_listing(self):
"""File listing containing 'not found' in filenames should not trigger."""
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False

View File

@@ -60,7 +60,7 @@ from ..service import (
_generate_session_title,
_is_langfuse_configured,
)
from ..tools.e2b_sandbox import get_or_create_sandbox
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
from ..tracking import track_user_message
@@ -456,31 +456,6 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
def _is_tool_error_or_denial(content: str | None) -> bool:
"""Check if a tool message content indicates an error or denial.
Currently unused — ``_format_conversation_context`` includes all tool
results. Kept as a utility for future selective filtering.
"""
if not content:
return False
lower = content.lower()
return any(
marker in lower
for marker in (
"[security]",
"cannot be bypassed",
"not allowed",
"not supported", # background-task denial
"maximum", # subtask-limit denial
"denied",
"blocked",
"failed to", # internal tool execution failures
'"iserror": true', # MCP protocol error flag
)
)
async def _build_query_message(
current_message: str,
session: ChatSession,
@@ -784,28 +759,29 @@ async def stream_chat_completion_sdk(
async def _setup_e2b():
"""Set up E2B sandbox if configured, return sandbox or None."""
if config.use_e2b_sandbox and not config.e2b_api_key:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
)
return None
if config.use_e2b_sandbox and config.e2b_api_key:
try:
return await get_or_create_sandbox(
session_id,
api_key=config.e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
if not (e2b_api_key := config.active_e2b_api_key):
if config.use_e2b_sandbox:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
try:
return await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
on_timeout=config.e2b_sandbox_on_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
async def _fetch_transcript():
@@ -837,7 +813,6 @@ async def stream_chat_completion_sdk(
system_prompt = base_system_prompt + get_sdk_supplement(
use_e2b=use_e2b, cwd=sdk_cwd
)
# Process transcript download result
transcript_msg_count = 0
if dl:
@@ -902,6 +877,11 @@ async def stream_chat_completion_sdk(
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
def _on_stderr(line: str) -> None:
sid = session_id[:12] if session_id else "?"
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
@@ -910,6 +890,7 @@ async def stream_chat_completion_sdk(
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
"stderr": _on_stderr,
}
if sdk_model:
sdk_options_kwargs["model"] = sdk_model
@@ -1082,6 +1063,19 @@ async def stream_chat_completion_sdk(
len(adapter.resolved_tool_calls),
)
# Log AssistantMessage API errors (e.g. invalid_request)
# so we can debug Anthropic API 400s surfaced by the CLI.
sdk_error = getattr(sdk_msg, "error", None)
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
logger.error(
"[SDK] [%s] AssistantMessage has error=%s, "
"content_blocks=%d, content_preview=%s",
session_id[:12],
sdk_error,
len(sdk_msg.content),
str(sdk_msg.content)[:500],
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
@@ -1416,6 +1410,17 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
# --- Pause E2B sandbox to stop billing between turns ---
# Fire-and-forget: pausing is best-effort and must not block the
# response or the transcript upload. The task is anchored to
# _background_tasks to prevent garbage collection.
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
# round-trip — e2b_sandbox is the live object from this turn.
if e2b_sandbox is not None:
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception.

View File

@@ -1,9 +1,10 @@
"""Tests for SDK service helpers."""
import asyncio
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -212,7 +213,7 @@ class TestPromptSupplement:
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "suggested_goal" in docs or "clarifying_questions" in docs
assert "agent_json" in docs or "find_block" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
@@ -231,6 +232,48 @@ class TestPromptSupplement:
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
def test_pause_task_scheduled_before_transcript_upload(self):
"""Pause is scheduled as a background task before transcript upload begins.
The finally block in stream_response_sdk does:
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
(2) await asyncio.shield(upload_transcript(...)) — awaited
Scheduling pause via create_task before awaiting upload ensures:
- Pause never blocks transcript upload (billing stops concurrently)
- On E2B timeout, pause silently fails; upload proceeds unaffected
"""
call_order: list[str] = []
async def _mock_pause(sandbox, session_id):
call_order.append("pause")
async def _mock_upload(**kwargs):
call_order.append("upload")
async def _simulate_teardown():
"""Mirror the service.py finally block teardown sequence."""
sandbox = MagicMock()
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
# Yielding to the event loop here lets the pause task start concurrently.
await _mock_upload(
user_id="u", session_id="test-sess", content="x", message_count=1
)
await task
asyncio.run(_simulate_teardown())
# Both must run; pause is scheduled before upload starts
assert "pause" in call_order
assert "upload" in call_order
# create_task schedules pause, then upload is awaited — pause runs
# concurrently during upload's first yield. The ordering guarantee is
# that create_task is CALLED before upload is AWAITED (see source order).
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement

View File

@@ -9,14 +9,29 @@ import itertools
import json
import logging
import os
import re
import uuid
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.context import (
_current_project_dir,
_current_sandbox,
_current_sdk_cwd,
_current_session,
_current_user_id,
_encode_cwd_for_cli,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import (
FileRefExpansionError,
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
@@ -28,84 +43,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Check whether *path* is an allowed host-filesystem path.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
project directory for this session (tool-results, transcripts, etc.)
Both checks are scoped to the **current session** so sessions cannot
read each other's data.
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
# Allow access within the current session's CLI project directory
# (~/.claude/projects/<encoded-cwd>/).
encoded = _current_project_dir.get("")
if encoded:
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
if resolved == session_project or resolved.startswith(session_project + os.sep):
return True
return False
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
# Used by workspace tools to save binary files for the CLI's built-in Read.
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -149,24 +93,6 @@ def set_execution_context(
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK ephemeral working directory for the current turn."""
return _current_sdk_cwd.get()
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the oldest stashed output for *tool_name*.
@@ -259,7 +185,11 @@ async def _execute_tool_sync(
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
"""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -320,42 +250,50 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a local file with optional offset/limit.
"""Read a file with optional offset/limit.
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
Supports ``workspace://`` URIs (delegated to the workspace manager) and
local paths within the session's allowed directories (sdk_cwd + tool-results).
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
offset = max(0, int(args.get("offset", 0)))
limit = max(1, int(args.get("limit", 2000)))
if not is_allowed_local_path(file_path):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
def _mcp_err(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": True}
def _mcp_ok(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": False}
if file_path.startswith("workspace://"):
user_id, session = get_execution_context()
if session is None:
return _mcp_err("workspace:// file references require an active session")
try:
raw = await read_file_bytes(file_path, user_id, session)
except ValueError as exc:
return _mcp_err(str(exc))
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp_ok(numbered)
if not is_allowed_local_path(file_path, get_sdk_cwd()):
return _mcp_err(f"Path not allowed: {file_path}")
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
return _mcp_ok("".join(selected))
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
return _mcp_err(f"File not found: {file_path}")
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
return _mcp_err(f"Error reading file: {e}")
_READ_TOOL_NAME = "Read"
@@ -414,9 +352,23 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
Also expands ``@@agptfile:`` references in args so every registered tool
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(args, user_id, session)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
"Ensure the file exists before referencing it. "
"For sandbox paths use bash_exec to verify the file exists first; "
"for workspace files use a workspace:// URI."
)
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)

View File

@@ -2,12 +2,12 @@
import pytest
from backend.copilot.context import get_sdk_cwd
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_sdk_cwd,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,

View File

@@ -19,7 +19,10 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
@@ -32,6 +35,7 @@ from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
@@ -65,9 +69,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
"get_agent_building_guide": GetAgentBuildingGuideTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
@@ -80,6 +86,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Agent generation tools (local validation/fixing)
"validate_agent_graph": ValidateAgentGraphTool(),
"fix_agent_graph": FixAgentGraphTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -33,7 +33,7 @@ import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from backend.util.request import validate_url_host
from .base import BaseTool
from .models import (
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
# ---------------------------------------------------------------------------
# SSRF protection via shared validate_url (backend.util.request)
# SSRF protection via shared validate_url_host (backend.util.request)
# ---------------------------------------------------------------------------
# Patch target: validate_url is imported directly into agent_browser's module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
# Patch target: validate_url_host is imported directly into agent_browser's
# module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
class TestSsrfViaValidateUrl:
"""Verify that browser_navigate uses validate_url for SSRF protection.
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
We mock validate_url itself (not the low-level socket) so these tests
We mock validate_url_host itself (not the low-level socket) so these tests
exercise the integration point, not the internals of request.py
(which has its own thorough test suite in request_test.py).
"""
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
@pytest.mark.asyncio
async def test_blocked_ip_returns_blocked_url_error(self):
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.side_effect = ValueError(
"Access to blocked IP 10.0.0.1 is not allowed."
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
assert result.error == "blocked_url"
@pytest.mark.asyncio
async def test_validate_url_called_with_empty_trusted_origins(self):
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
async def test_validate_url_host_called_without_trusted_hostnames(self):
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.return_value = (object(), False, ["1.2.3.4"])
with patch(
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
session=self.session,
url="https://example.com",
)
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
mock_validate.assert_called_once_with("https://example.com")
# ---------------------------------------------------------------------------

View File

@@ -1,20 +1,15 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
@@ -27,25 +22,20 @@ from .core import (
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
from .validation import AgentFixer, AgentValidator
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentFixer",
"AgentValidator",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
@@ -54,7 +44,6 @@ __all__ = [
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",

View File

@@ -0,0 +1,66 @@
"""Block management for agent generation.
Provides cached access to block metadata for validation and fixing.
"""
import logging
from typing import Any, Type
from backend.blocks import get_blocks as get_block_classes
from backend.blocks._base import Block
logger = logging.getLogger(__name__)
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
# ---------------------------------------------------------------------------
# Module-level caches
# ---------------------------------------------------------------------------
_blocks_cache: list[dict[str, Any]] | None = None
def reset_block_caches() -> None:
"""Reset all module-level caches (useful after updating block descriptions)."""
global _blocks_cache
_blocks_cache = None
# ---------------------------------------------------------------------------
# 1. get_blocks_as_dicts
# ---------------------------------------------------------------------------
def get_blocks_as_dicts() -> list[dict[str, Any]]:
"""Get all available blocks as dicts (cached after first call).
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
id, name, description, inputSchema, outputSchema, categories,
staticOutput, costs, contributors, uiType.
Returns:
List of block info dicts.
"""
global _blocks_cache
if _blocks_cache is not None:
return _blocks_cache
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
blocks: list[dict[str, Any]] = []
for block_cls in block_classes.values():
try:
instance = block_cls()
info = instance.get_info().model_dump()
# Use optimized description if available (loaded at startup)
if instance.optimized_description:
info["description"] = instance.optimized_description
blocks.append(info)
except Exception:
logger.warning(
"Failed to load block info for %s, skipping",
getattr(block_cls, "__name__", "unknown"),
exc_info=True,
)
_blocks_cache = blocks
logger.info("Cached %d block dicts", len(blocks))
return _blocks_cache

View File

@@ -10,13 +10,7 @@ from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
from .helpers import UUID_RE_STR
logger = logging.getLogger(__name__)
@@ -78,38 +72,7 @@ class DecompositionResult(TypedDict, total=False):
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
def extract_uuids_from_text(text: str) -> list[str]:
@@ -553,69 +516,6 @@ async def enrich_library_agents_from_steps(
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
@@ -792,70 +692,3 @@ async def get_agent_as_json(
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)

View File

@@ -1,165 +0,0 @@
"""Dummy Agent Generator for testing.
Returns mock responses matching the format expected from the external service.
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from typing import Any
logger = logging.getLogger(__name__)
# Dummy decomposition result (instructions type)
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
"type": "instructions",
"steps": [
{
"description": "Get input from user",
"action": "input",
"block_name": "AgentInputBlock",
},
{
"description": "Process the input",
"action": "process",
"block_name": "TextFormatterBlock",
},
{
"description": "Return output to user",
"action": "output",
"block_name": "AgentOutputBlock",
},
],
}
# Block IDs from backend/blocks/io.py
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def _generate_dummy_agent_json() -> dict[str, Any]:
"""Generate a minimal valid agent JSON for testing."""
input_node_id = str(uuid.uuid4())
output_node_id = str(uuid.uuid4())
return {
"id": str(uuid.uuid4()),
"version": 1,
"is_active": True,
"name": "Dummy Test Agent",
"description": "A dummy agent generated for testing purposes",
"nodes": [
{
"id": input_node_id,
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {
"name": "input",
"title": "Input",
"description": "Enter your input",
"placeholder_values": [],
},
"metadata": {"position": {"x": 0, "y": 0}},
},
{
"id": output_node_id,
"block_id": AGENT_OUTPUT_BLOCK_ID,
"input_default": {
"name": "output",
"title": "Output",
"description": "Agent output",
"format": "{output}",
},
"metadata": {"position": {"x": 400, "y": 0}},
},
],
"links": [
{
"id": str(uuid.uuid4()),
"source_id": input_node_id,
"sink_id": output_node_id,
"source_name": "result",
"sink_name": "value",
"is_static": False,
},
],
}
async def decompose_goal_dummy(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Return dummy decomposition result."""
logger.info("Using dummy agent generator for decompose_goal")
return DUMMY_DECOMPOSITION_RESULT.copy()
async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
async def generate_agent_patch_dummy(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"
)
return patched
async def customize_template_dummy(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any]:
"""Return dummy customized template (returns template with updated description)."""
logger.info("Using dummy agent generator for customize_template")
customized = template_agent.copy()
customized["description"] = (
f"{template_agent.get('description', '')} (customized: {modification_request})"
)
return customized
async def get_blocks_dummy() -> list[dict[str, Any]]:
"""Return dummy blocks list."""
logger.info("Using dummy agent generator for get_blocks")
return [
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
]
async def health_check_dummy() -> bool:
"""Always returns healthy for dummy service."""
return True

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,913 @@
"""Unit tests for AgentFixer."""
from .fixer import (
_ADDTODICTIONARY_BLOCK_ID,
_ADDTOLIST_BLOCK_ID,
_CODE_EXECUTION_BLOCK_ID,
_DATA_SAMPLING_BLOCK_ID,
_GET_CURRENT_DATE_BLOCK_ID,
_STORE_VALUE_BLOCK_ID,
_TEXT_REPLACE_BLOCK_ID,
_UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
AGENT_EXECUTOR_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
AgentFixer,
)
from .helpers import generate_uuid
def _make_agent(
nodes: list | None = None,
links: list | None = None,
agent_id: str | None = None,
) -> dict:
"""Build a minimal agent dict for testing."""
return {
"id": agent_id or generate_uuid(),
"name": "Test Agent",
"nodes": nodes or [],
"links": links or [],
}
def _make_node(
node_id: str | None = None,
block_id: str = "block-1",
input_default: dict | None = None,
position: tuple[int, int] = (0, 0),
) -> dict:
"""Build a minimal node dict for testing."""
return {
"id": node_id or generate_uuid(),
"block_id": block_id,
"input_default": input_default or {},
"metadata": {"position": {"x": position[0], "y": position[1]}},
}
def _make_link(
link_id: str | None = None,
source_id: str = "",
source_name: str = "output",
sink_id: str = "",
sink_name: str = "input",
is_static: bool = False,
) -> dict:
"""Build a minimal link dict for testing."""
return {
"id": link_id or generate_uuid(),
"source_id": source_id,
"source_name": source_name,
"sink_id": sink_id,
"sink_name": sink_name,
"is_static": is_static,
}
class TestFixAgentIds:
"""Tests for fix_agent_ids."""
def test_valid_uuids_unchanged(self):
fixer = AgentFixer()
agent_id = generate_uuid()
link_id = generate_uuid()
agent = _make_agent(agent_id=agent_id, links=[{"id": link_id}])
result = fixer.fix_agent_ids(agent)
assert result["id"] == agent_id
assert result["links"][0]["id"] == link_id
assert fixer.fixes_applied == []
def test_invalid_agent_id_replaced(self):
fixer = AgentFixer()
agent = _make_agent(agent_id="bad-id")
result = fixer.fix_agent_ids(agent)
assert result["id"] != "bad-id"
assert len(fixer.fixes_applied) == 1
assert "agent ID" in fixer.fixes_applied[0]
def test_invalid_link_id_replaced(self):
fixer = AgentFixer()
agent = _make_agent(links=[{"id": "not-a-uuid"}])
result = fixer.fix_agent_ids(agent)
assert result["links"][0]["id"] != "not-a-uuid"
assert len(fixer.fixes_applied) == 1
class TestFixDoubleCurlyBraces:
"""Tests for fix_double_curly_braces."""
def test_single_braces_converted_to_double(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": "Hello {name}!"})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
def test_double_braces_unchanged(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
assert fixer.fixes_applied == []
def test_non_string_prompt_skipped(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": 42})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == 42
def test_non_string_prompt_with_prompt_values_skipped(self):
"""Ensure non-string prompt fields don't crash re.search in the
prompt_values path."""
fixer = AgentFixer()
node_id = generate_uuid()
source_id = generate_uuid()
node = _make_node(
node_id=node_id, input_default={"prompt": None, "prompt_values": {}}
)
source_node = _make_node(node_id=source_id)
link = _make_link(
source_id=source_id,
source_name="output",
sink_id=node_id,
sink_name="prompt_values_$_name",
)
agent = _make_agent(nodes=[node, source_node], links=[link])
result = fixer.fix_double_curly_braces(agent)
# Should not crash and prompt stays None
assert result["nodes"][0]["input_default"]["prompt"] is None
class TestFixCredentials:
"""Tests for fix_credentials."""
def test_credentials_removed(self):
fixer = AgentFixer()
node = _make_node(
input_default={
"credentials": {"key": "secret"},
"url": "http://example.com",
}
)
agent = _make_agent(nodes=[node])
result = fixer.fix_credentials(agent)
assert "credentials" not in result["nodes"][0]["input_default"]
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
assert len(fixer.fixes_applied) == 1
def test_no_credentials_unchanged(self):
fixer = AgentFixer()
node = _make_node(input_default={"url": "http://example.com"})
agent = _make_agent(nodes=[node])
result = fixer.fix_credentials(agent)
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
assert fixer.fixes_applied == []
class TestFixCodeExecutionOutput:
"""Tests for fix_code_execution_output."""
def test_response_renamed_to_stdout_logs(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
link = _make_link(source_id="n1", source_name="response", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_code_execution_output(agent)
assert result["links"][0]["source_name"] == "stdout_logs"
assert len(fixer.fixes_applied) == 1
def test_non_response_source_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
link = _make_link(source_id="n1", source_name="stdout_logs", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_code_execution_output(agent)
assert result["links"][0]["source_name"] == "stdout_logs"
assert fixer.fixes_applied == []
class TestFixDataSamplingSampleSize:
"""Tests for fix_data_sampling_sample_size."""
def test_sample_size_set_to_1(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=_DATA_SAMPLING_BLOCK_ID,
input_default={"sample_size": 10},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_data_sampling_sample_size(agent)
assert result["nodes"][0]["input_default"]["sample_size"] == 1
def test_removes_links_to_sample_size(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_DATA_SAMPLING_BLOCK_ID)
link = _make_link(sink_id="n1", sink_name="sample_size", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_data_sampling_sample_size(agent)
assert len(result["links"]) == 0
assert result["nodes"][0]["input_default"]["sample_size"] == 1
class TestFixTextReplaceNewParameter:
"""Tests for fix_text_replace_new_parameter."""
def test_empty_new_changed_to_space(self):
fixer = AgentFixer()
node = _make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": ""},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_text_replace_new_parameter(agent)
assert result["nodes"][0]["input_default"]["new"] == " "
def test_nonempty_new_unchanged(self):
fixer = AgentFixer()
node = _make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": "replacement"},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_text_replace_new_parameter(agent)
assert result["nodes"][0]["input_default"]["new"] == "replacement"
assert fixer.fixes_applied == []
class TestFixGetCurrentDateOffset:
"""Tests for fix_getcurrentdate_offset."""
def test_negative_offset_made_positive(self):
fixer = AgentFixer()
node = _make_node(
block_id=_GET_CURRENT_DATE_BLOCK_ID,
input_default={"offset": -5},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_getcurrentdate_offset(agent)
assert result["nodes"][0]["input_default"]["offset"] == 5
def test_positive_offset_unchanged(self):
fixer = AgentFixer()
node = _make_node(
block_id=_GET_CURRENT_DATE_BLOCK_ID,
input_default={"offset": 3},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_getcurrentdate_offset(agent)
assert result["nodes"][0]["input_default"]["offset"] == 3
assert fixer.fixes_applied == []
class TestFixNodeXCoordinates:
"""Tests for fix_node_x_coordinates."""
def test_close_nodes_spread_apart(self):
fixer = AgentFixer()
src_node = _make_node(node_id="src", position=(0, 0))
sink_node = _make_node(node_id="sink", position=(100, 0))
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_node_x_coordinates(agent)
sink = next(n for n in result["nodes"] if n["id"] == "sink")
assert sink["metadata"]["position"]["x"] >= 800
def test_far_apart_nodes_unchanged(self):
fixer = AgentFixer()
src_node = _make_node(node_id="src", position=(0, 0))
sink_node = _make_node(node_id="sink", position=(1000, 0))
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_node_x_coordinates(agent)
sink = next(n for n in result["nodes"] if n["id"] == "sink")
assert sink["metadata"]["position"]["x"] == 1000
assert fixer.fixes_applied == []
class TestFixAddToDictionaryBlocks:
"""Tests for fix_addtodictionary_blocks."""
def test_removes_create_dictionary_nodes(self):
fixer = AgentFixer()
create_dict_id = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
dict_node = _make_node(node_id="dict-1", block_id=create_dict_id)
add_to_dict_node = _make_node(
node_id="add-1", block_id=_ADDTODICTIONARY_BLOCK_ID
)
link = _make_link(source_id="dict-1", sink_id="add-1")
agent = _make_agent(nodes=[dict_node, add_to_dict_node], links=[link])
result = fixer.fix_addtodictionary_blocks(agent)
node_ids = [n["id"] for n in result["nodes"]]
assert "dict-1" not in node_ids
assert "add-1" in node_ids
assert len(result["links"]) == 0
class TestFixStoreValueBeforeCondition:
"""Tests for fix_storevalue_before_condition."""
def test_inserts_storevalue_block(self):
fixer = AgentFixer()
condition_block_id = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
src_node = _make_node(node_id="src")
cond_node = _make_node(node_id="cond", block_id=condition_block_id)
link = _make_link(
source_id="src", source_name="output", sink_id="cond", sink_name="value2"
)
agent = _make_agent(nodes=[src_node, cond_node], links=[link])
result = fixer.fix_storevalue_before_condition(agent)
# Should have 3 nodes now (original 2 + new StoreValueBlock)
assert len(result["nodes"]) == 3
store_nodes = [
n for n in result["nodes"] if n["block_id"] == _STORE_VALUE_BLOCK_ID
]
assert len(store_nodes) == 1
assert store_nodes[0]["input_default"]["data"] is None
class TestFixAddToListBlocks:
"""Tests for fix_addtolist_blocks - self-reference links."""
def test_addtolist_gets_self_reference_link(self):
fixer = AgentFixer()
node = _make_node(node_id="atl-1", block_id=_ADDTOLIST_BLOCK_ID)
# Source link to AddToList (from some other node)
link = _make_link(
source_id="other",
source_name="output",
sink_id="atl-1",
sink_name="item",
)
other_node = _make_node(node_id="other")
agent = _make_agent(nodes=[other_node, node], links=[link])
result = fixer.fix_addtolist_blocks(agent)
# Should have a self-reference link: atl-1.updated_list -> atl-1.list
self_ref_links = [
lnk
for lnk in result["links"]
if lnk["source_id"] == "atl-1"
and lnk["sink_id"] == "atl-1"
and lnk["source_name"] == "updated_list"
and lnk["sink_name"] == "list"
]
assert len(self_ref_links) == 1
class TestFixLinkStaticProperties:
"""Tests for fix_link_static_properties."""
def test_sets_is_static_from_block_schema(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n1", sink_id="n2", is_static=False)
agent = _make_agent(nodes=[node], links=[link])
blocks = [{"id": block_id, "staticOutput": True}]
result = fixer.fix_link_static_properties(agent, blocks)
assert result["links"][0]["is_static"] is True
def test_unknown_block_leaves_link_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id="unknown-block")
link = _make_link(source_id="n1", sink_id="n2", is_static=True)
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_link_static_properties(agent, blocks=[])
# Unknown block → skipped, link stays as-is
assert result["links"][0]["is_static"] is True
class TestFixAiModelParameter:
"""Tests for fix_ai_model_parameter."""
def test_missing_model_gets_default(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id, input_default={})
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {"model": {"type": "string"}},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "gpt-4o"
def test_valid_model_unchanged(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=block_id,
input_default={"model": "claude-opus-4-6"},
)
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {"model": {"type": "string"}},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
class TestFixAgentExecutorBlocks:
"""Tests for fix_agent_executor_blocks."""
def test_fills_schemas_from_library_agent(self):
fixer = AgentFixer()
lib_agent_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_agent_id,
"graph_version": 1,
"user_id": "user-1",
},
)
agent = _make_agent(nodes=[node])
# Library agents use graph_id as the lookup key
library_agents = [
{
"graph_id": lib_agent_id,
"graph_version": 2,
"input_schema": {"field1": {"type": "string"}},
"output_schema": {"result": {"type": "string"}},
}
]
result = fixer.fix_agent_executor_blocks(agent, library_agents)
node_result = result["nodes"][0]["input_default"]
assert node_result["graph_version"] == 2
assert node_result["input_schema"] == {"field1": {"type": "string"}}
assert node_result["output_schema"] == {"result": {"type": "string"}}
class TestFixInvalidNestedSinkLinks:
"""Tests for fix_invalid_nested_sink_links."""
def test_removes_numeric_index_links(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_0")
agent = _make_agent(nodes=[node], links=[link])
blocks = [
{
"id": block_id,
"inputSchema": {"properties": {"values": {"type": "array"}}},
}
]
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
assert len(result["links"]) == 0
def test_valid_nested_links_kept(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_name")
agent = _make_agent(nodes=[node], links=[link])
blocks = [
{
"id": block_id,
"inputSchema": {
"properties": {"values": {"type": "object"}},
},
}
]
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
assert len(result["links"]) == 1
class TestApplyAllFixes:
"""Tests for apply_all_fixes orchestration."""
def test_is_sync(self):
"""apply_all_fixes should be a sync function."""
import inspect
assert not inspect.iscoroutinefunction(AgentFixer.apply_all_fixes)
def test_applies_multiple_fixes(self):
fixer = AgentFixer()
agent = _make_agent(
agent_id="bad-id",
nodes=[
_make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": "", "credentials": {"key": "secret"}},
)
],
)
result = fixer.apply_all_fixes(agent)
# Agent ID should be fixed
assert result["id"] != "bad-id"
# Credentials should be removed
assert "credentials" not in result["nodes"][0]["input_default"]
# Text replace "new" should be space
assert result["nodes"][0]["input_default"]["new"] == " "
# Multiple fixes applied
assert len(fixer.fixes_applied) >= 3
def test_empty_agent_no_crash(self):
fixer = AgentFixer()
agent = _make_agent()
result = fixer.apply_all_fixes(agent)
assert "nodes" in result
assert "links" in result
def test_returns_deep_copy_behavior(self):
"""Fixer mutates in place — verify the same dict is returned."""
fixer = AgentFixer()
agent = _make_agent()
result = fixer.apply_all_fixes(agent)
assert result is agent
class TestFixMCPToolBlocks:
"""Tests for fix_mcp_tool_blocks."""
def test_adds_missing_tool_arguments(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
assert result["nodes"][0]["input_default"]["tool_arguments"] == {}
assert any("tool_arguments" in f for f in fixer.fixes_applied)
def test_adds_missing_tool_input_schema(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
assert result["nodes"][0]["input_default"]["tool_input_schema"] == {}
assert any("tool_input_schema" in f for f in fixer.fixes_applied)
def test_populates_tool_arguments_from_schema(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {
"properties": {
"query": {"type": "string", "default": "hello"},
"limit": {"type": "integer"},
}
},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
tool_args = result["nodes"][0]["input_default"]["tool_arguments"]
assert tool_args["query"] == "hello"
assert tool_args["limit"] is None
def test_no_op_when_already_complete(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
fixer.fix_mcp_tool_blocks(agent)
assert len(fixer.fixes_applied) == 0
class TestFixDynamicBlockSinkNames:
"""Tests for fix_dynamic_block_sink_names."""
def test_mcp_tool_arguments_prefix_removed(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
link = _make_link(
source_id="src", sink_id="n1", sink_name="tool_arguments_#_query"
)
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "query"
assert len(fixer.fixes_applied) == 1
def test_agent_executor_inputs_prefix_removed(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=AGENT_EXECUTOR_BLOCK_ID)
link = _make_link(source_id="src", sink_id="n1", sink_name="inputs_#_url")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "url"
assert len(fixer.fixes_applied) == 1
def test_bare_sink_name_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
link = _make_link(source_id="src", sink_id="n1", sink_name="query")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "query"
assert len(fixer.fixes_applied) == 0
def test_non_dynamic_block_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id="some-other-block-id")
link = _make_link(source_id="src", sink_id="n1", sink_name="values_#_key")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "values_#_key"
assert len(fixer.fixes_applied) == 0
class TestFixDataTypeMismatch:
"""Tests for fix_data_type_mismatch."""
@staticmethod
def _make_block(
block_id: str,
name: str = "TestBlock",
input_schema: dict | None = None,
output_schema: dict | None = None,
) -> dict:
return {
"id": block_id,
"name": name,
"inputSchema": input_schema or {"properties": {}},
"outputSchema": output_schema or {"properties": {}},
}
def test_inserts_converter_for_incompatible_types(self):
fixer = AgentFixer()
src_block_id = generate_uuid()
sink_block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=src_block_id)
sink_node = _make_node(node_id="sink", block_id=sink_block_id)
link = _make_link(
source_id="src",
source_name="result",
sink_id="sink",
sink_name="count",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
blocks = [
self._make_block(
src_block_id,
name="Source",
output_schema={"properties": {"result": {"type": "string"}}},
),
self._make_block(
sink_block_id,
name="Sink",
input_schema={"properties": {"count": {"type": "integer"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
# A converter node should have been inserted
converter_nodes = [
n
for n in result["nodes"]
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
]
assert len(converter_nodes) == 1
assert converter_nodes[0]["input_default"]["type"] == "number"
# Original link replaced by two new links through the converter
assert len(result["links"]) == 2
src_to_converter = result["links"][0]
converter_to_sink = result["links"][1]
assert src_to_converter["source_id"] == "src"
assert src_to_converter["sink_id"] == converter_nodes[0]["id"]
assert src_to_converter["sink_name"] == "value"
assert converter_to_sink["source_id"] == converter_nodes[0]["id"]
assert converter_to_sink["source_name"] == "value"
assert converter_to_sink["sink_id"] == "sink"
assert converter_to_sink["sink_name"] == "count"
assert len(fixer.fixes_applied) == 1
def test_compatible_types_unchanged(self):
fixer = AgentFixer()
block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=block_id)
sink_node = _make_node(node_id="sink", block_id=block_id)
link = _make_link(
source_id="src",
source_name="output",
sink_id="sink",
sink_name="input",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
blocks = [
self._make_block(
block_id,
input_schema={"properties": {"input": {"type": "string"}}},
output_schema={"properties": {"output": {"type": "string"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
# No converter inserted, original link kept
assert len(result["nodes"]) == 2
assert len(result["links"]) == 1
assert result["links"][0] is link
assert fixer.fixes_applied == []
def test_missing_block_keeps_link(self):
"""Links referencing unknown blocks are kept unchanged."""
fixer = AgentFixer()
src_node = _make_node(node_id="src", block_id="unknown-block")
sink_node = _make_node(node_id="sink", block_id="unknown-block")
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_data_type_mismatch(agent, blocks=[])
assert len(result["links"]) == 1
assert result["links"][0] is link
def test_missing_type_info_keeps_link(self):
"""Links where source/sink type is not defined are kept unchanged."""
fixer = AgentFixer()
block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=block_id)
sink_node = _make_node(node_id="sink", block_id=block_id)
link = _make_link(
source_id="src",
source_name="output",
sink_id="sink",
sink_name="input",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
# Block has no properties defined for the linked fields
blocks = [self._make_block(block_id)]
result = fixer.fix_data_type_mismatch(agent, blocks)
assert len(result["links"]) == 1
assert fixer.fixes_applied == []
def test_multiple_mismatches_insert_multiple_converters(self):
"""Each incompatible link gets its own converter node."""
fixer = AgentFixer()
src_block_id = generate_uuid()
sink_block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=src_block_id)
sink1 = _make_node(node_id="sink1", block_id=sink_block_id)
sink2 = _make_node(node_id="sink2", block_id=sink_block_id)
link1 = _make_link(
source_id="src", source_name="out", sink_id="sink1", sink_name="count"
)
link2 = _make_link(
source_id="src", source_name="out", sink_id="sink2", sink_name="count"
)
agent = _make_agent(nodes=[src_node, sink1, sink2], links=[link1, link2])
blocks = [
self._make_block(
src_block_id,
output_schema={"properties": {"out": {"type": "string"}}},
),
self._make_block(
sink_block_id,
input_schema={"properties": {"count": {"type": "integer"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
converter_nodes = [
n
for n in result["nodes"]
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
]
assert len(converter_nodes) == 2
# Each original link becomes two links through its own converter
assert len(result["links"]) == 4
assert len(fixer.fixes_applied) == 2

View File

@@ -0,0 +1,67 @@
"""Shared helpers for agent generation."""
import re
import uuid
from typing import Any
from .blocks import get_blocks_as_dicts
__all__ = [
"AGENT_EXECUTOR_BLOCK_ID",
"AGENT_INPUT_BLOCK_ID",
"AGENT_OUTPUT_BLOCK_ID",
"AgentDict",
"MCP_TOOL_BLOCK_ID",
"UUID_REGEX",
"are_types_compatible",
"generate_uuid",
"get_blocks_as_dicts",
"get_defined_property_type",
"is_uuid",
]
# Type alias for the agent JSON structure passed through
# the validation and fixing pipeline.
AgentDict = dict[str, Any]
# Shared base pattern (unanchored, lowercase hex); used for both full-string
# validation (UUID_REGEX) and text extraction (core._UUID_PATTERN).
UUID_RE_STR = r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[a-f0-9]{4}-[a-f0-9]{12}"
UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def is_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def generate_uuid() -> str:
"""Generate a new UUID string."""
return str(uuid.uuid4())
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
"""Get property type from a schema, handling nested `_#_` notation."""
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema and isinstance(
parent_schema["properties"], dict
):
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
"""Check if two schema types are compatible."""
if {src, sink} <= {"integer", "number"}:
return True
return src == sink

View File

@@ -0,0 +1,196 @@
"""Shared fix → validate → preview/save pipeline for agent tools."""
import json
import logging
from typing import Any, cast
from backend.copilot.tools.models import (
AgentPreviewResponse,
AgentSavedResponse,
ErrorResponse,
ToolResponseBase,
)
from .blocks import get_blocks_as_dicts
from .core import get_library_agents_by_ids, save_agent_to_library
from .fixer import AgentFixer
from .validator import AgentValidator
logger = logging.getLogger(__name__)
MAX_AGENT_JSON_SIZE = 1_000_000 # 1 MB
async def fetch_library_agents(
user_id: str | None,
library_agent_ids: list[str],
) -> list[dict[str, Any]] | None:
"""Fetch library agents by IDs for AgentExecutorBlock validation.
Returns None if no IDs provided or user is not authenticated.
"""
if not user_id or not library_agent_ids:
return None
try:
agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=library_agent_ids,
)
return cast(list[dict[str, Any]], agents)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
return None
async def fix_validate_and_save(
agent_json: dict[str, Any],
*,
user_id: str | None,
session_id: str | None,
save: bool = True,
is_update: bool = False,
default_name: str = "Agent",
preview_message: str | None = None,
save_message: str | None = None,
library_agents: list[dict[str, Any]] | None = None,
folder_id: str | None = None,
) -> ToolResponseBase:
"""Shared pipeline: auto-fix → validate → preview or save.
Args:
agent_json: The agent JSON dict (must already have id/version/is_active set).
user_id: The authenticated user's ID.
session_id: The chat session ID.
save: Whether to save or just preview.
is_update: Whether this is an update to an existing agent.
default_name: Fallback name if agent_json has none.
preview_message: Custom preview message (optional).
save_message: Custom save success message (optional).
library_agents: Library agents for AgentExecutorBlock validation/fixing.
Returns:
An appropriate ToolResponseBase subclass.
"""
# Size guard
json_size = len(json.dumps(agent_json))
if json_size > MAX_AGENT_JSON_SIZE:
return ErrorResponse(
message=(
f"Agent JSON is too large ({json_size:,} bytes, "
f"max {MAX_AGENT_JSON_SIZE:,}). Reduce the number of nodes."
),
error="agent_json_too_large",
session_id=session_id,
)
blocks = get_blocks_as_dicts()
# Auto-fix
try:
fixer = AgentFixer()
agent_json = fixer.apply_all_fixes(agent_json, blocks, library_agents)
fixes = fixer.get_fixes_applied()
if fixes:
logger.info(f"Applied {len(fixes)} auto-fixes to agent JSON")
except Exception as e:
logger.warning(f"Auto-fix failed: {e}")
# Validate
try:
validator = AgentValidator()
is_valid, _ = validator.validate(agent_json, blocks, library_agents)
if not is_valid:
errors = validator.errors
return ErrorResponse(
message=(
f"The agent has {len(errors)} validation error(s):\n"
+ "\n".join(f"- {e}" for e in errors[:5])
),
error="validation_failed",
details={"errors": errors},
session_id=session_id,
)
except Exception as e:
logger.error(f"Validation failed with exception: {e}", exc_info=True)
return ErrorResponse(
message="Failed to validate the agent. Please try again.",
error="validation_exception",
details={"exception": str(e)},
session_id=session_id,
)
agent_name = agent_json.get("name", default_name)
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Build a warning suffix when name/description is missing or generic
_GENERIC_NAMES = {
"agent",
"generated agent",
"customized agent",
"updated agent",
"new agent",
"my agent",
}
metadata_warnings: list[str] = []
if not agent_json.get("name") or agent_name.lower().strip() in _GENERIC_NAMES:
metadata_warnings.append("'name'")
if not agent_description:
metadata_warnings.append("'description'")
metadata_hint = ""
if metadata_warnings:
missing = " and ".join(metadata_warnings)
metadata_hint = (
f" Note: the agent is missing a meaningful {missing}. "
f"Please update the agent_json to include them."
)
if not save:
return AgentPreviewResponse(
message=(
(
preview_message
or f"Agent '{agent_name}' with {node_count} blocks is ready."
)
+ metadata_hint
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update, folder_id=folder_id
)
return AgentSavedResponse(
message=(
(save_message or f"Agent '{created_graph.name}' has been saved!")
+ metadata_hint
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Failed to save agent: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -1,511 +0,0 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
from .dummy import (
customize_template_dummy,
decompose_goal_dummy,
generate_agent_dummy,
generate_agent_patch_dummy,
get_blocks_dummy,
health_check_dummy,
)
logger = logging.getLogger(__name__)
_dummy_mode_warned = False
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def _is_dummy_mode() -> bool:
"""Check if dummy mode is enabled for testing."""
global _dummy_mode_warned
settings = _get_settings()
is_dummy = bool(settings.config.agentgenerator_use_dummy)
if is_dummy and not _dummy_mode_warned:
logger.warning(
"Agent Generator running in DUMMY MODE - returning mock responses. "
"Do not use in production!"
)
_dummy_mode_warned = True
return is_dummy
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured (or dummy mode)."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host) or bool(
settings.config.agentgenerator_use_dummy
)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
if _is_dummy_mode():
return await decompose_goal_dummy(description, context, library_agents)
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(instructions, library_agents)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/generate-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents
)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/update-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
if _is_dummy_mode():
return await customize_template_dummy(
template_agent, modification_request, context
)
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
if _is_dummy_mode():
return await get_blocks_dummy()
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
if _is_dummy_mode():
return await health_check_dummy()
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -0,0 +1,17 @@
"""Agent generation validation — re-exports from split modules.
This module was split into:
- helpers.py: get_blocks_as_dicts, block cache
- fixer.py: AgentFixer class
- validator.py: AgentValidator class
"""
from .fixer import AgentFixer
from .helpers import get_blocks_as_dicts
from .validator import AgentValidator
__all__ = [
"AgentFixer",
"AgentValidator",
"get_blocks_as_dicts",
]

View File

@@ -0,0 +1,939 @@
"""AgentValidator — validates agent JSON graphs for correctness."""
import json
import logging
import re
from typing import Any
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
AgentDict,
are_types_compatible,
get_defined_property_type,
)
logger = logging.getLogger(__name__)
class AgentValidator:
"""
A comprehensive validator for AutoGPT agents that provides detailed error
reporting for LLM-based fixes.
"""
def __init__(self):
self.errors: list[str] = []
def add_error(self, error_message: str) -> None:
"""Add an error message to the validation errors list."""
self.errors.append(error_message)
def _values_equal(self, val1: Any, val2: Any) -> bool:
"""Compare two values, handling complex types like dicts and lists."""
if type(val1) is not type(val2):
return False
if isinstance(val1, dict):
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
if isinstance(val1, list):
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
return val1 == val2
def validate_block_existence(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all block IDs used in the agent actually exist in the
blocks list. Returns True if all block IDs exist, False otherwise.
"""
valid = True
# Create a set of all valid block IDs for fast lookup
valid_block_ids = {block.get("id") for block in blocks if block.get("id")}
# Check each node's block_id
for node in agent.get("nodes", []):
block_id = node.get("block_id")
node_id = node.get("id")
if not block_id:
self.add_error(
f"Node '{node_id}' is missing a 'block_id' field. "
f"Every node must reference a valid block."
)
valid = False
continue
if block_id not in valid_block_ids:
self.add_error(
f"Node '{node_id}' references block_id '{block_id}' "
f"which does not exist in the available blocks. "
f"This block may have been deprecated, removed, or "
f"the ID is incorrect. Please use a valid block from "
f"the blocks library."
)
valid = False
return valid
def validate_link_node_references(self, agent: AgentDict) -> bool:
"""
Validate that all node IDs referenced in links actually exist in the
agent's nodes. Returns True if all link references are valid, False
otherwise.
"""
valid = True
# Create a set of all valid node IDs for fast lookup
valid_node_ids = {
node.get("id") for node in agent.get("nodes", []) if node.get("id")
}
# Check each link's source_id and sink_id
for link in agent.get("links", []):
link_id = link.get("id", "Unknown")
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_name = link.get("source_name", "")
sink_name = link.get("sink_name", "")
# Check source_id
if not source_id:
self.add_error(
f"Link '{link_id}' is missing a 'source_id' field. "
f"Every link must reference a valid source node."
)
valid = False
elif source_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references source_id '{source_id}' "
f"which does not exist in the agent's nodes. The link "
f"from '{source_name}' cannot be established because "
f"the source node is missing."
)
valid = False
# Check sink_id
if not sink_id:
self.add_error(
f"Link '{link_id}' is missing a 'sink_id' field. "
f"Every link must reference a valid sink (destination) "
f"node."
)
valid = False
elif sink_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references sink_id '{sink_id}' "
f"which does not exist in the agent's nodes. The link "
f"to '{sink_name}' cannot be established because the "
f"destination node is missing."
)
valid = False
return valid
def validate_required_inputs(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all required inputs are provided for each node.
Returns True if all required inputs are satisfied, False otherwise.
"""
valid = True
block_lookup = {b.get("id", ""): b for b in blocks}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_lookup.get(block_id)
if not block:
continue
required_inputs = block.get("inputSchema", {}).get("required", [])
input_defaults = node.get("input_default", {})
node_id = node.get("id")
linked_inputs = set(
link.get("sink_name")
for link in agent.get("links", [])
if link.get("sink_id") == node_id and link.get("sink_name")
)
for req_input in required_inputs:
if (
req_input not in input_defaults
and req_input not in linked_inputs
and req_input != "credentials"
):
block_name = block.get("name", "Unknown Block")
self.add_error(
f"Node '{node_id}' (block '{block_name}' - "
f"{block_id}) is missing required input "
f"'{req_input}'. This input must be either "
f"provided as a default value in the node's "
f"'input_default' field or connected via a link "
f"from another node's output."
)
valid = False
return valid
def validate_data_type_compatibility(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that linked data types are compatible between source and sink.
Returns True if all data types are compatible, False otherwise.
"""
valid = True
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
block_lookup = {block.get("id", ""): block for block in blocks}
for link in agent.get("links", []):
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_name = link.get("source_name")
sink_name = link.get("sink_name")
if not all(
isinstance(v, str) and v
for v in (source_id, sink_id, source_name, sink_name)
):
self.add_error(
f"Link '{link.get('id', 'Unknown')}' is missing required "
f"fields (source_id/sink_id/source_name/sink_name)."
)
valid = False
continue
source_node = node_lookup.get(source_id, "")
sink_node = node_lookup.get(sink_id, "")
if not source_node or not sink_node:
continue
source_block = block_lookup.get(source_node.get("block_id", ""))
sink_block = block_lookup.get(sink_node.get("block_id", ""))
if not source_block or not sink_block:
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_defined_property_type(source_outputs, source_name)
sink_type = get_defined_property_type(sink_inputs, sink_name)
if (
source_type
and sink_type
and not are_types_compatible(source_type, sink_type)
):
source_block_name = source_block.get("name", "Unknown Block")
sink_block_name = sink_block.get("name", "Unknown Block")
self.add_error(
f"Data type mismatch in link '{link.get('id')}': "
f"Source '{source_block_name}' output "
f"'{link.get('source_name', '')}' outputs '{source_type}' "
f"type, but sink '{sink_block_name}' input "
f"'{link.get('sink_name', '')}' expects '{sink_type}' type. "
f"These types must match for the connection to work "
f"properly."
)
valid = False
return valid
def validate_nested_sink_links(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate nested sink links (links with _#_ notation).
Returns True if all nested links are valid, False otherwise.
"""
valid = True
block_input_schemas = {
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
sink_id = link.get("sink_id")
if not sink_name or not sink_id:
continue
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(sink_id)
if not sink_node:
continue
block_id = sink_node.get("block_id")
input_props = block_input_schemas.get(block_id, {})
parent_schema = input_props.get(parent)
if not parent_schema:
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' for "
f"node '{sink_id}' (block "
f"'{block_name}' - {block_id}): Parent property "
f"'{parent}' does not exist in the block's "
f"input schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
# Check anyOf for additionalProperties
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' "
f"for node '{link.get('sink_id', '')}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' schema. Available "
f"properties: "
f"{list(parent_schema.get('properties', {}).keys())}"
)
valid = False
return valid
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
"""
Validate that prompt parameters do not contain spaces in double curly
braces.
Checks the 'prompt' parameter in input_default of each node and reports
errors if values within double curly braces ({{...}}) contain spaces.
For example, {{user name}} should be {{user_name}}.
Args:
agent: The agent dictionary to validate
Returns:
True if all prompts are valid (no spaces in double curly braces),
False otherwise
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
node_id = node.get("id")
input_default = node.get("input_default", {})
# Check if 'prompt' parameter exists
if "prompt" not in input_default:
continue
prompt_text = input_default["prompt"]
# Only process if it's a string
if not isinstance(prompt_text, str):
continue
# Find all double curly brace patterns with spaces
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt_text)
for match in matches:
content = match.group(1)
if " " in content:
start_pos = match.start()
snippet_start = max(0, start_pos - 30)
snippet_end = min(len(prompt_text), match.end() + 30)
snippet = prompt_text[snippet_start:snippet_end]
self.add_error(
f"Node '{node_id}' has spaces in double curly "
f"braces in prompt parameter: "
f"'{{{{{content}}}}}' should be "
f"'{{{{{content.replace(' ', '_')}}}}}'. "
f"Context: ...{snippet}..."
)
valid = False
return valid
def validate_source_output_existence(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all source_names in links exist in the corresponding
block's output schema.
Checks that for each link, the source_name field references a valid
output property in the source block's outputSchema. Also handles nested
outputs with _#_ notation.
Args:
agent: The agent dictionary to validate
blocks: List of available blocks with their schemas
Returns:
True if all source output fields exist, False otherwise
"""
valid = True
# Create lookup dictionaries for efficiency
block_output_schemas = {
block.get("id", ""): block.get("outputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
for link in agent.get("links", []):
source_id = link.get("source_id")
source_name = link.get("source_name", "")
link_id = link.get("id", "Unknown")
if not source_name:
self.add_error(
f"Link '{link_id}' is missing 'source_name'. "
f"Every link must specify which output field to read from."
)
valid = False
continue
source_node = node_lookup.get(source_id)
if not source_node:
# This error is already caught by
# validate_link_node_references
continue
block_id = source_node.get("block_id")
block_name = block_names.get(block_id, "Unknown Block")
# Special handling for AgentExecutorBlock - use dynamic
# output_schema from input_default
if block_id == AGENT_EXECUTOR_BLOCK_ID:
input_default = source_node.get("input_default", {})
dynamic_output_schema = input_default.get("output_schema", {})
if not isinstance(dynamic_output_schema, dict):
dynamic_output_schema = {}
output_props = dynamic_output_schema.get("properties", {})
if not isinstance(output_props, dict):
output_props = {}
else:
output_props = block_output_schemas.get(block_id, {})
# Handle nested source names (with _#_ notation)
if "_#_" in source_name:
parent, child = source_name.split("_#_", 1)
parent_schema = output_props.get(parent)
if not parent_schema:
self.add_error(
f"Invalid source output field '{source_name}' "
f"in link '{link_id}' from node '{source_id}' "
f"(block '{block_name}' - {block_id}): Parent "
f"property '{parent}' does not exist in the "
f"block's output schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
# Also allow when items have
# additionalProperties (array of objects)
if (
isinstance(schema_option, dict)
and "items" in schema_option
):
items_schema = schema_option.get("items")
if isinstance(items_schema, dict) and items_schema.get(
"additionalProperties"
):
allows_additional_properties = True
break
# Only require child in properties when
# additionalProperties is not allowed
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
available_props = (
list(parent_schema.get("properties", {}).keys())
if isinstance(parent_schema, dict)
else []
)
self.add_error(
f"Invalid nested source output field "
f"'{source_name}' in link '{link_id}' from "
f"node '{source_id}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' output schema. "
f"Available properties: {available_props}"
)
valid = False
else:
# Check simple (non-nested) source name
if source_name not in output_props:
available_outputs = list(output_props.keys())
self.add_error(
f"Invalid source output field '{source_name}' "
f"in link '{link_id}' from node '{source_id}' "
f"(block '{block_name}' - {block_id}): Output "
f"property '{source_name}' does not exist in "
f"the block's output schema. Available outputs: "
f"{available_outputs}"
)
valid = False
return valid
def validate_io_blocks(self, agent: AgentDict) -> bool:
"""
Validate that the agent has at least one AgentInputBlock and one
AgentOutputBlock. These blocks define the agent's interface.
Returns True if both are present, False otherwise.
"""
valid = True
block_ids = {node.get("block_id") for node in agent.get("nodes", [])}
if AGENT_INPUT_BLOCK_ID not in block_ids:
self.add_error(
f"Agent is missing an AgentInputBlock (block_id: "
f"'{AGENT_INPUT_BLOCK_ID}'). Every agent must have at "
f"least one AgentInputBlock to define user-facing inputs. "
f"Add a node with block_id '{AGENT_INPUT_BLOCK_ID}' and "
f"set input_default with 'name' and optionally 'title'."
)
valid = False
if AGENT_OUTPUT_BLOCK_ID not in block_ids:
self.add_error(
f"Agent is missing an AgentOutputBlock (block_id: "
f"'{AGENT_OUTPUT_BLOCK_ID}'). Every agent must have at "
f"least one AgentOutputBlock to define user-facing outputs. "
f"Add a node with block_id '{AGENT_OUTPUT_BLOCK_ID}' and "
f"set input_default with 'name', then link 'value' from "
f"another block's output."
)
valid = False
return valid
def validate_agent_executor_blocks(
self,
agent: AgentDict,
library_agents: list[dict[str, Any]] | None = None,
) -> bool:
"""
Validate AgentExecutorBlock nodes have required fields and valid
references.
Checks that AgentExecutorBlock nodes:
1. Have a valid graph_id in input_default (required)
2. If graph_id matches a known library agent, validates version
consistency
3. Sub-agent required inputs are connected via links (not hardcoded)
Note: Unknown graph_ids are not treated as errors - they could be valid
direct references to agents by their actual ID (not via library_agents).
This is consistent with fix_agent_executor_blocks() behavior.
Args:
agent: The agent dictionary to validate
library_agents: List of available library agents (for version
validation)
Returns:
True if all AgentExecutorBlock nodes are valid, False otherwise
"""
valid = True
nodes = agent.get("nodes", [])
links = agent.get("links", [])
# Create lookup for library agents
library_agent_lookup: dict[str, dict[str, Any]] = {}
if library_agents:
library_agent_lookup = {la.get("graph_id", ""): la for la in library_agents}
for node in nodes:
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
continue
node_id = node.get("id")
input_default = node.get("input_default", {})
# Check for required graph_id
graph_id = input_default.get("graph_id")
if not graph_id:
self.add_error(
f"AgentExecutorBlock node '{node_id}' is missing "
f"required 'graph_id' in input_default. This field "
f"must reference the ID of the sub-agent to execute."
)
valid = False
continue
# If graph_id is not in library_agent_lookup, skip validation
if graph_id not in library_agent_lookup:
continue
# Validate version consistency for known library agents
library_agent = library_agent_lookup[graph_id]
expected_version = library_agent.get("graph_version")
current_version = input_default.get("graph_version")
if (
current_version
and expected_version
and current_version != expected_version
):
self.add_error(
f"AgentExecutorBlock node '{node_id}' has mismatched "
f"graph_version: got {current_version}, expected "
f"{expected_version} for library agent "
f"'{library_agent.get('name')}'"
)
valid = False
# Validate sub-agent inputs are properly linked (not hardcoded)
sub_agent_input_schema = library_agent.get("input_schema", {})
if not isinstance(sub_agent_input_schema, dict):
sub_agent_input_schema = {}
sub_agent_required_inputs = sub_agent_input_schema.get("required", [])
sub_agent_properties = sub_agent_input_schema.get("properties", {})
# Get all linked inputs to this node
linked_sub_agent_inputs: set[str] = set()
for link in links:
if link.get("sink_id") == node_id:
sink_name = link.get("sink_name", "")
if sink_name in sub_agent_properties:
linked_sub_agent_inputs.add(sink_name)
# Check for hardcoded inputs that should be linked
hardcoded_inputs = input_default.get("inputs", {})
input_schema = input_default.get("input_schema", {})
schema_properties = (
input_schema.get("properties", {})
if isinstance(input_schema, dict)
else {}
)
if isinstance(hardcoded_inputs, dict) and hardcoded_inputs:
for input_name, value in hardcoded_inputs.items():
if input_name not in sub_agent_properties:
continue
if value is None:
continue
# Skip if this input is already linked
if input_name in linked_sub_agent_inputs:
continue
prop_schema = schema_properties.get(input_name, {})
schema_default = (
prop_schema.get("default")
if isinstance(prop_schema, dict)
else None
)
if schema_default is not None and self._values_equal(
value, schema_default
):
continue
# This is a non-default hardcoded value without a link
self.add_error(
f"AgentExecutorBlock node '{node_id}' has "
f"hardcoded input '{input_name}'. Sub-agent "
f"inputs should be connected via links using "
f"'{input_name}' as sink_name, not hardcoded "
f"in input_default.inputs. Create a link from "
f"the appropriate source node."
)
valid = False
# Check for missing required sub-agent inputs.
# An input is satisfied if it is linked OR has an allowed
# hardcoded value (i.e. equals the schema default — the
# previous check already flags non-default hardcoded values).
hardcoded_inputs_dict = (
hardcoded_inputs if isinstance(hardcoded_inputs, dict) else {}
)
for req_input in sub_agent_required_inputs:
if req_input in linked_sub_agent_inputs:
continue
# Check if fixer populated it with a schema default value
if req_input in hardcoded_inputs_dict:
prop_schema = schema_properties.get(req_input, {})
schema_default = (
prop_schema.get("default")
if isinstance(prop_schema, dict)
else None
)
if schema_default is not None and self._values_equal(
hardcoded_inputs_dict[req_input], schema_default
):
continue
self.add_error(
f"AgentExecutorBlock node '{node_id}' is "
f"missing required sub-agent input "
f"'{req_input}'. Create a link to this node "
f"using sink_name '{req_input}' to connect "
f"the input."
)
valid = False
return valid
def validate_agent_executor_block_schemas(
self,
agent: AgentDict,
) -> bool:
"""
Validate that AgentExecutorBlock nodes have valid input_schema and
output_schema.
This validation runs regardless of library_agents availability and
ensures that the schemas are properly populated to prevent frontend
crashes.
Args:
agent: The agent dictionary to validate
Returns:
True if all AgentExecutorBlock nodes have valid schemas, False
otherwise
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
continue
node_id = node.get("id")
input_default = node.get("input_default", {})
customized_name = (node.get("metadata") or {}).get(
"customized_name", "Unknown"
)
# Check input_schema
input_schema = input_default.get("input_schema")
if input_schema is None or not isinstance(input_schema, dict):
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has missing or invalid "
f"input_schema. The input_schema must be a valid "
f"JSON Schema object with 'properties' and "
f"'required' fields."
)
valid = False
elif not input_schema.get("properties") and not input_schema.get("type"):
# Empty schema like {} is invalid
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has empty input_schema. The "
f"input_schema must define the sub-agent's expected "
f"inputs. This usually indicates the sub-agent "
f"reference is incomplete or the library agent was "
f"not properly passed."
)
valid = False
# Check output_schema
output_schema = input_default.get("output_schema")
if output_schema is None or not isinstance(output_schema, dict):
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has missing or invalid "
f"output_schema. The output_schema must be a valid "
f"JSON Schema object defining the sub-agent's "
f"outputs."
)
valid = False
elif not output_schema.get("properties") and not output_schema.get("type"):
# Empty schema like {} is invalid
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has empty output_schema. "
f"The output_schema must define the sub-agent's "
f"expected outputs. This usually indicates the "
f"sub-agent reference is incomplete or the library "
f"agent was not properly passed."
)
valid = False
return valid
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
"""Validate that MCPToolBlock nodes have required fields.
Checks that each MCPToolBlock node has:
1. A non-empty `server_url` in input_default
2. A non-empty `selected_tool` in input_default
Returns True if all MCPToolBlock nodes are valid, False otherwise.
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
if node.get("block_id") != MCP_TOOL_BLOCK_ID:
continue
node_id = node.get("id", "unknown")
input_default = node.get("input_default", {})
customized_name = (node.get("metadata") or {}).get(
"customized_name", node_id
)
server_url = input_default.get("server_url")
if not server_url:
self.add_error(
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
f"missing required 'server_url' in input_default. "
f"Set this to the MCP server URL "
f"(e.g. 'https://mcp.example.com/sse')."
)
valid = False
selected_tool = input_default.get("selected_tool")
if not selected_tool:
self.add_error(
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
f"missing required 'selected_tool' in input_default. "
f"Set this to the name of the MCP tool to execute."
)
valid = False
return valid
def validate(
self,
agent: AgentDict,
blocks: list[dict[str, Any]],
library_agents: list[dict[str, Any]] | None = None,
) -> tuple[bool, str | None]:
"""
Comprehensive validation of an agent against available blocks.
Returns:
Tuple[bool, Optional[str]]: (is_valid, error_message)
- is_valid: True if agent passes all validations, False otherwise
- error_message: Detailed error message if validation fails, None
if successful
"""
logger.info("Validating agent...")
self.errors = []
checks = [
(
"Block existence",
self.validate_block_existence(agent, blocks),
),
(
"Link node references",
self.validate_link_node_references(agent),
),
(
"Required inputs",
self.validate_required_inputs(agent, blocks),
),
(
"Data type compatibility",
self.validate_data_type_compatibility(agent, blocks),
),
(
"Nested sink links",
self.validate_nested_sink_links(agent, blocks),
),
(
"Source output existence",
self.validate_source_output_existence(agent, blocks),
),
(
"Prompt double curly braces spaces",
self.validate_prompt_double_curly_braces_spaces(agent),
),
(
"IO blocks",
self.validate_io_blocks(agent),
),
# Always validate AgentExecutorBlock schemas to prevent
# frontend crashes
(
"AgentExecutorBlock schemas",
self.validate_agent_executor_block_schemas(agent),
),
(
"MCP tool blocks",
self.validate_mcp_tool_blocks(agent),
),
]
# Add AgentExecutorBlock detailed validation if library_agents
# provided
if library_agents:
checks.append(
(
"AgentExecutorBlock references",
self.validate_agent_executor_blocks(agent, library_agents),
)
)
all_passed = all(check[1] for check in checks)
if all_passed:
logger.info("Agent validation successful.")
return True, None
else:
error_message = "Agent validation failed with the following errors:\n\n"
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.error(f"Agent validation failed: {error_message}")
return False, error_message

View File

@@ -0,0 +1,710 @@
"""Unit tests for AgentValidator."""
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
generate_uuid,
)
from .validator import AgentValidator
def _make_agent(
nodes: list | None = None,
links: list | None = None,
agent_id: str | None = None,
) -> dict:
"""Build a minimal agent dict for testing."""
return {
"id": agent_id or generate_uuid(),
"name": "Test Agent",
"nodes": nodes or [],
"links": links or [],
}
def _make_node(
node_id: str | None = None,
block_id: str = "block-1",
input_default: dict | None = None,
position: tuple[int, int] = (0, 0),
) -> dict:
return {
"id": node_id or generate_uuid(),
"block_id": block_id,
"input_default": input_default or {},
"metadata": {"position": {"x": position[0], "y": position[1]}},
}
def _make_link(
link_id: str | None = None,
source_id: str = "",
source_name: str = "output",
sink_id: str = "",
sink_name: str = "input",
) -> dict:
return {
"id": link_id or generate_uuid(),
"source_id": source_id,
"source_name": source_name,
"sink_id": sink_id,
"sink_name": sink_name,
}
def _make_block(
block_id: str = "block-1",
name: str = "TestBlock",
input_schema: dict | None = None,
output_schema: dict | None = None,
categories: list | None = None,
static_output: bool = False,
) -> dict:
return {
"id": block_id,
"name": name,
"inputSchema": input_schema or {"properties": {}, "required": []},
"outputSchema": output_schema or {"properties": {}},
"categories": categories or [],
"staticOutput": static_output,
}
# ============================================================================
# validate_block_existence
# ============================================================================
class TestValidateBlockExistence:
def test_valid_blocks_pass(self):
v = AgentValidator()
node = _make_node(block_id="b1")
block = _make_block(block_id="b1")
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, [block]) is True
assert v.errors == []
def test_missing_block_fails(self):
v = AgentValidator()
node = _make_node(block_id="nonexistent")
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, []) is False
assert len(v.errors) == 1
assert "does not exist" in v.errors[0]
def test_missing_block_id_field(self):
v = AgentValidator()
node = {"id": "n1", "input_default": {}, "metadata": {}}
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, []) is False
assert "missing a 'block_id'" in v.errors[0]
# ============================================================================
# validate_link_node_references
# ============================================================================
class TestValidateLinkNodeReferences:
def test_valid_references_pass(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
n2 = _make_node(node_id="n2")
link = _make_link(source_id="n1", sink_id="n2")
agent = _make_agent(nodes=[n1, n2], links=[link])
assert v.validate_link_node_references(agent) is True
assert v.errors == []
def test_invalid_source_fails(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
link = _make_link(source_id="missing", sink_id="n1")
agent = _make_agent(nodes=[n1], links=[link])
assert v.validate_link_node_references(agent) is False
assert any("source_id" in e for e in v.errors)
def test_invalid_sink_fails(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
link = _make_link(source_id="n1", sink_id="missing")
agent = _make_agent(nodes=[n1], links=[link])
assert v.validate_link_node_references(agent) is False
assert any("sink_id" in e for e in v.errors)
# ============================================================================
# validate_required_inputs
# ============================================================================
class TestValidateRequiredInputs:
def test_satisfied_by_default_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(block_id="b1", input_default={"url": "http://example.com"})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is True
assert v.errors == []
def test_satisfied_by_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n2", sink_id="n1", sink_name="url")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_required_inputs(agent, [block]) is True
def test_missing_required_input_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(block_id="b1", input_default={})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is False
assert any("missing required input" in e for e in v.errors)
def test_credentials_always_allowed_missing(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"credentials": {"type": "object"}},
"required": ["credentials"],
},
)
node = _make_node(block_id="b1", input_default={})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is True
# ============================================================================
# validate_data_type_compatibility
# ============================================================================
class TestValidateDataTypeCompatibility:
def test_matching_types_pass(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "string"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "string"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
)
def test_int_number_compatible(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "integer"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "number"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
)
def test_mismatched_types_fail(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "string"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "integer"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is False
)
assert any("mismatch" in e.lower() for e in v.errors)
# ============================================================================
# validate_source_output_existence
# ============================================================================
class TestValidateSourceOutputExistence:
def test_valid_source_output_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
output_schema={"properties": {"result": {"type": "string"}}},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n1", source_name="result", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_source_output_existence(agent, [block]) is True
def test_invalid_source_output_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
output_schema={"properties": {"result": {"type": "string"}}},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n1", source_name="nonexistent", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_source_output_existence(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_prompt_double_curly_braces_spaces
# ============================================================================
class TestValidatePromptDoubleCurlyBracesSpaces:
def test_no_spaces_passes(self):
v = AgentValidator()
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
agent = _make_agent(nodes=[node])
assert v.validate_prompt_double_curly_braces_spaces(agent) is True
def test_spaces_in_braces_fails(self):
v = AgentValidator()
node = _make_node(input_default={"prompt": "Hello {{user name}}!"})
agent = _make_agent(nodes=[node])
assert v.validate_prompt_double_curly_braces_spaces(agent) is False
assert any("spaces" in e for e in v.errors)
# ============================================================================
# validate_nested_sink_links
# ============================================================================
class TestValidateNestedSinkLinks:
def test_valid_nested_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is True
def test_invalid_parent_fails(self):
v = AgentValidator()
block = _make_block(block_id="b1")
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_block_schemas
# ============================================================================
class TestValidateAgentExecutorBlockSchemas:
def test_valid_schemas_pass(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {"properties": {"q": {"type": "string"}}},
"output_schema": {"properties": {"result": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is True
assert v.errors == []
def test_empty_input_schema_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {},
"output_schema": {"properties": {"result": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is False
assert any("empty input_schema" in e for e in v.errors)
def test_missing_output_schema_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {"properties": {"q": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is False
assert any("output_schema" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_blocks
# ============================================================================
class TestValidateAgentExecutorBlocks:
def test_missing_graph_id_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_blocks(agent) is False
assert any("graph_id" in e for e in v.errors)
def test_valid_graph_id_passes(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={"graph_id": generate_uuid()},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_blocks(agent) is True
def test_version_mismatch_with_library_agent(self):
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={"graph_id": lib_id, "graph_version": 1},
)
agent = _make_agent(nodes=[node])
library_agents = [{"graph_id": lib_id, "graph_version": 3, "name": "Sub Agent"}]
assert v.validate_agent_executor_blocks(agent, library_agents) is False
assert any("mismatched graph_version" in e for e in v.errors)
def test_required_input_satisfied_by_schema_default_passes(self):
"""Required sub-agent inputs filled with their schema default by the fixer
should NOT be flagged as missing."""
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_id,
"input_schema": {
"properties": {"mode": {"type": "string", "default": "fast"}}
},
"inputs": {"mode": "fast"}, # fixer populated with schema default
},
)
agent = _make_agent(nodes=[node])
library_agents = [
{
"graph_id": lib_id,
"graph_version": 1,
"name": "Sub",
"input_schema": {
"required": ["mode"],
"properties": {"mode": {"type": "string", "default": "fast"}},
},
"output_schema": {},
}
]
assert v.validate_agent_executor_blocks(agent, library_agents) is True
assert v.errors == []
def test_required_input_not_linked_and_no_default_fails(self):
"""Required sub-agent inputs without a link or schema default must fail."""
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_id,
"input_schema": {"properties": {"query": {"type": "string"}}},
"inputs": {},
},
)
agent = _make_agent(nodes=[node])
library_agents = [
{
"graph_id": lib_id,
"graph_version": 1,
"name": "Sub",
"input_schema": {
"required": ["query"],
"properties": {"query": {"type": "string"}},
},
"output_schema": {},
}
]
assert v.validate_agent_executor_blocks(agent, library_agents) is False
assert any("missing required sub-agent input" in e for e in v.errors)
# ============================================================================
# validate_io_blocks
# ============================================================================
class TestValidateIoBlocks:
def test_missing_input_block_reports_error(self):
v = AgentValidator()
# Agent has output block but no input block
node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 1
assert "AgentInputBlock" in v.errors[0]
def test_missing_output_block_reports_error(self):
v = AgentValidator()
# Agent has input block but no output block
node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 1
assert "AgentOutputBlock" in v.errors[0]
def test_missing_both_io_blocks_reports_two_errors(self):
v = AgentValidator()
node = _make_node(block_id="some-other-block")
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 2
def test_both_io_blocks_present_no_error(self):
v = AgentValidator()
input_node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
output_node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
agent = _make_agent(nodes=[input_node, output_node])
assert v.validate_io_blocks(agent) is True
assert v.errors == []
def test_empty_agent_reports_both_missing(self):
v = AgentValidator()
agent = _make_agent(nodes=[])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 2
# ============================================================================
# validate (integration)
# ============================================================================
class TestValidate:
def test_valid_agent_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
output_schema={"properties": {"result": {"type": "string"}}},
)
input_block = _make_block(
block_id=AGENT_INPUT_BLOCK_ID,
name="AgentInputBlock",
output_schema={"properties": {"result": {}}},
)
output_block = _make_block(
block_id=AGENT_OUTPUT_BLOCK_ID,
name="AgentOutputBlock",
)
input_node = _make_node(
node_id="n-in",
block_id=AGENT_INPUT_BLOCK_ID,
input_default={"name": "url"},
)
n1 = _make_node(
node_id="n1", block_id="b1", input_default={"url": "http://example.com"}
)
n2 = _make_node(
node_id="n2", block_id="b1", input_default={"url": "http://example2.com"}
)
output_node = _make_node(
node_id="n-out",
block_id=AGENT_OUTPUT_BLOCK_ID,
input_default={"name": "result"},
)
link = _make_link(
source_id="n1", source_name="result", sink_id="n2", sink_name="url"
)
agent = _make_agent(nodes=[input_node, n1, n2, output_node], links=[link])
is_valid, error_message = v.validate(agent, [block, input_block, output_block])
assert is_valid is True
assert error_message is None
def test_invalid_agent_returns_errors(self):
v = AgentValidator()
node = _make_node(block_id="nonexistent")
agent = _make_agent(nodes=[node])
is_valid, error_message = v.validate(agent, [])
assert is_valid is False
assert error_message is not None
assert "does not exist" in error_message
def test_empty_agent_fails_io_validation(self):
v = AgentValidator()
agent = _make_agent()
is_valid, error_message = v.validate(agent, [])
assert is_valid is False
assert error_message is not None
assert "AgentInputBlock" in error_message
assert "AgentOutputBlock" in error_message
class TestValidateMCPToolBlocks:
"""Tests for validate_mcp_tool_blocks."""
def test_missing_server_url_reports_error(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={"selected_tool": "my_tool"},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is False
assert any("server_url" in e for e in v.errors)
def test_missing_selected_tool_reports_error(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={"server_url": "https://mcp.example.com/sse"},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is False
assert any("selected_tool" in e for e in v.errors)
def test_valid_mcp_block_passes(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {"properties": {"query": {"type": "string"}}},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is True
assert len(v.errors) == 0
def test_both_missing_reports_two_errors(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={},
)
agent = _make_agent(nodes=[node])
v.validate_mcp_tool_blocks(agent)
assert len(v.errors) == 2

View File

@@ -208,6 +208,9 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
graph_version=agent.graph_version,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)

View File

@@ -21,10 +21,10 @@ from typing import Any
from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.model import ChatSession
from .base import BaseTool
from .e2b_sandbox import E2B_WORKDIR
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
@@ -94,9 +94,6 @@ class BashExecTool(BaseTool):
session_id=session_id,
)
# E2B path: run on remote cloud sandbox when available.
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)

View File

@@ -1,34 +1,20 @@
"""CreateAgentTool - Creates agents from natural language descriptions."""
"""CreateAgentTool - Creates agents from pre-built JSON."""
import logging
import uuid
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_user_message_for_error,
save_agent_to_library,
)
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
SuggestedGoalResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
"""Tool for creating agents from pre-built JSON."""
@property
def name(self) -> str:
@@ -37,15 +23,12 @@ class CreateAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
"(2) Call create_agent with description and library_agent_ids. "
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
"then call again with the suggested goal if accepted. "
"(4) If response contains clarifying_questions: Present to user, collect answers, "
"then call again with original description AND answers in the context parameter. "
"\n\nThis feedback loop ensures the generated agent matches user intent."
"Create a new agent workflow. Pass `agent_json` with the complete "
"agent graph JSON you generated using block schemas from find_block. "
"The tool validates, auto-fixes, and saves.\n\n"
"IMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter."
)
@property
@@ -57,34 +40,26 @@ class CreateAgentTool(BaseTool):
return {
"type": "object",
"properties": {
"description": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"Natural language description of what the agent should do. "
"Be specific about inputs, outputs, and the workflow steps."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions. "
"Include any preferences or constraints mentioned by the user."
"The agent JSON to validate and save. "
"Must contain 'nodes' and 'links' arrays, and optionally "
"'name' and 'description'."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
"List of library agent IDs to use as building blocks."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the agent to the user's library. "
"Default is true. Set to false for preview only."
"Whether to save the agent. Default is true. "
"Set to false for preview only."
),
"default": True,
},
@@ -97,7 +72,7 @@ class CreateAgentTool(BaseTool):
),
},
},
"required": ["description"],
"required": ["agent_json"],
}
async def _execute(
@@ -106,278 +81,49 @@ class CreateAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the create_agent tool.
Flow:
1. Decompose the description into steps (may return clarifying questions)
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
if not agent_json:
return ErrorResponse(
message=(
"Please provide agent_json with the complete agent graph. "
"Use find_block to discover blocks, then generate the JSON."
),
error="missing_agent_json",
session_id=session_id,
)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes. An agent needs at least one block.",
error="empty_agent",
session_id=session_id,
)
# Ensure top-level fields
if "id" not in agent_json:
agent_json["id"] = str(uuid.uuid4())
if "version" not in agent_json:
agent_json["version"] = 1
if "is_active" not in agent_json:
agent_json["is_active"] = True
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=False,
default_name="Generated Agent",
library_agents=library_agents,
folder_id=folder_id,
)
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
error="Missing description parameter",
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=library_agent_ids,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={"description": description[:100]},
session_id=session_id,
)
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="analyze the goal",
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information to create this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
return SuggestedGoalResponse(
message=(
f"This goal cannot be accomplished with the available blocks. {reason}"
),
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="unachievable",
session_id=session_id,
)
if decomposition_result.get("type") == "vague_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get(
"reason", "The goal needs more specific details"
)
return SuggestedGoalResponse(
message="The goal is too vague to create a specific workflow.",
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="vague",
session_id=session_id,
)
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")
try:
agent_json = await generate_agent(
decomposition_result,
library_agents,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": description[:100]},
session_id=session_id,
)
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
f"Review it and call create_agent with save=true to save it to your library."
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, folder_id=folder_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -1,19 +1,16 @@
"""Tests for CreateAgentTool response types."""
"""Tests for CreateAgentTool."""
from unittest.mock import AsyncMock, patch
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.create_agent import CreateAgentTool
from backend.copilot.tools.models import (
ClarificationNeededResponse,
ErrorResponse,
SuggestedGoalResponse,
)
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-create-agent"
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
@pytest.fixture
@@ -26,102 +23,147 @@ def session():
return make_session(_TEST_USER_ID)
# ── Input validation tests ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_missing_description_returns_error(tool, session):
"""Missing description returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session, description="")
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error == "Missing description parameter"
assert result.error == "missing_agent_json"
# ── Local mode tests ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_vague_goal_returns_suggested_goal_response(tool, session):
"""vague_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
vague_result = {
"type": "vague_goal",
"suggested_goal": "Monitor Twitter mentions for a specific keyword and send a daily digest email",
}
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=vague_result,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="monitor social media",
)
assert isinstance(result, SuggestedGoalResponse)
assert result.goal_type == "vague"
assert result.suggested_goal == vague_result["suggested_goal"]
assert result.original_goal == "monitor social media"
assert result.reason == "The goal needs more specific details"
assert not isinstance(result, ErrorResponse)
async def test_local_mode_empty_nodes_returns_error(tool, session):
"""Local mode with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
"""unachievable_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
unachievable_result = {
"type": "unachievable_goal",
"suggested_goal": "Summarize the latest news articles on a topic and send them by email",
"reason": "There are no blocks for mind-reading.",
}
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=unachievable_result,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="read my mind",
)
assert isinstance(result, SuggestedGoalResponse)
assert result.goal_type == "unachievable"
assert result.suggested_goal == unachievable_result["suggested_goal"]
assert result.original_goal == "read my mind"
assert result.reason == unachievable_result["reason"]
assert not isinstance(result, ErrorResponse)
@pytest.mark.asyncio
async def test_clarifying_questions_returns_clarification_needed_response(
tool, session
):
"""clarifying_questions decomposition result returns ClarificationNeededResponse."""
clarifying_result = {
"type": "clarifying_questions",
"questions": [
async def test_local_mode_preview(tool, session):
"""Local mode with save=false returns AgentPreviewResponse."""
agent_json = {
"name": "Test Agent",
"description": "A test agent",
"nodes": [
{
"question": "What platform should be monitored?",
"keyword": "platform",
"example": "Twitter, Reddit",
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=clarifying_result,
),
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="monitor social media and alert me",
agent_json=agent_json,
save=False,
)
assert isinstance(result, ClarificationNeededResponse)
assert len(result.questions) == 1
assert result.questions[0].keyword == "platform"
assert isinstance(result, AgentPreviewResponse)
assert result.agent_name == "Test Agent"
assert result.node_count == 1
@pytest.mark.asyncio
async def test_local_mode_validation_failure(tool, session):
"""Local mode returns ErrorResponse when validation fails after fixing."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "bad-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
mock_validator.errors = ["Block 'bad-block' not found"]
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error == "validation_failed"
assert "Block 'bad-block' not found" in result.message
@pytest.mark.asyncio
async def test_local_mode_no_auth_returns_error(tool, session):
"""Local mode with save=true and no user returns ErrorResponse."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=None,
session=session,
agent_json=agent_json,
save=True,
)
assert isinstance(result, ErrorResponse)
assert "logged in" in result.message.lower()

View File

@@ -1,34 +1,20 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
"""CustomizeAgentTool - Customizes marketplace/template agents."""
import logging
import uuid
from typing import Any
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from backend.util.exceptions import NotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
customize_template,
get_user_message_for_error,
graph_to_json,
save_agent_to_library,
)
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
"""Tool for customizing marketplace/template agents."""
@property
def name(self) -> str:
@@ -37,9 +23,9 @@ class CustomizeAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent using natural language. "
"Takes an existing agent from the marketplace and modifies it based on "
"the user's requirements before adding to their library."
"Customize a marketplace or template agent. Pass `agent_json` "
"with the complete customized agent JSON. The tool validates, "
"auto-fixes, and saves."
)
@property
@@ -51,32 +37,24 @@ class CustomizeAgentTool(BaseTool):
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"The marketplace agent ID in format 'creator/slug' "
"(e.g., 'autogpt/newsletter-writer'). "
"Get this from find_agent results."
"Complete customized agent JSON to validate and save. "
"Optionally include 'name' and 'description'."
),
},
"modifications": {
"type": "string",
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"Natural language description of how to customize the agent. "
"Be specific about what changes you want to make."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
"List of library agent IDs to use as building blocks."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent to the user's library. "
"Default is true. Set to false for preview only."
"Whether to save the customized agent. Default is true."
),
"default": True,
},
@@ -89,7 +67,7 @@ class CustomizeAgentTool(BaseTool):
),
},
},
"required": ["agent_id", "modifications"],
"required": ["agent_json"],
}
async def _execute(
@@ -98,247 +76,46 @@ class CustomizeAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
Flow:
1. Parse the agent ID to get creator/slug
2. Fetch the template agent from the marketplace
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
session_id=session_id,
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
if not agent_json:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
"Please provide agent_json with the complete customized agent graph."
),
error="invalid_agent_id_format",
error="missing_agent_json",
session_id=session_id,
)
creator_username, agent_slug = parts
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
store_db = get_store_db()
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except NotFoundError:
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
message="The agent JSON has no nodes.",
error="empty_agent",
session_id=session_id,
)
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
session_id=session_id,
)
# Ensure top-level fields before the fixer pipeline
if "id" not in agent_json:
agent_json["id"] = str(uuid.uuid4())
agent_json.setdefault("version", 1)
agent_json.setdefault("is_active", True)
# Get the full agent graph
try:
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
session_id=session_id,
)
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
# Call customize_template
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent customization is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
"Please try again."
),
error="customization_service_error",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message=(
"Failed to customize the agent. "
"The agent generation service may be unavailable or timed out. "
"Please try again."
),
error="customization_failed",
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
message="Failed to customize the agent due to an unexpected response.",
error="unexpected_response_type",
session_id=session_id,
)
customized_agent = result
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=False,
default_name="Customized Agent",
library_agents=library_agents,
folder_id=folder_id,
)
agent_description = customized_agent.get("description", "")
nodes = customized_agent.get("nodes")
links = customized_agent.get("links")
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "
f"The customized agent has {node_count} blocks. "
f"Review it and call customize_agent with save=true to save it."
),
agent_json=customized_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False, folder_id=folder_id
)
return AgentSavedResponse(
message=(
f"Customized agent '{created_graph.name}' "
f"(based on '{agent_details.agent_name}') "
f"has been saved to your library!"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error saving customized agent: {e}")
return ErrorResponse(
message="Failed to save the customized agent. Please try again.",
error="save_failed",
session_id=session_id,
)

View File

@@ -0,0 +1,172 @@
"""Tests for CustomizeAgentTool local mode."""
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.customize_agent import CustomizeAgentTool
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-customize-agent"
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
@pytest.fixture
def tool():
return CustomizeAgentTool()
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
# ── Input validation tests ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_agent_json"
# ── Local mode tests (agent_json provided) ───────────────────────────────
@pytest.mark.asyncio
async def test_local_mode_empty_nodes_returns_error(tool, session):
"""Local mode with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_local_mode_preview(tool, session):
"""Local mode with save=false returns AgentPreviewResponse."""
agent_json = {
"name": "Customized Agent",
"description": "A customized agent",
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
save=False,
)
assert isinstance(result, AgentPreviewResponse)
assert result.agent_name == "Customized Agent"
assert result.node_count == 1
@pytest.mark.asyncio
async def test_local_mode_validation_failure(tool, session):
"""Local mode returns ErrorResponse when validation fails."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "bad-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
mock_validator.errors = ["Block 'bad-block' not found"]
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error == "validation_failed"
assert "Block 'bad-block' not found" in result.message
@pytest.mark.asyncio
async def test_local_mode_no_auth_returns_error(tool, session):
"""Local mode with save=true and no user returns ErrorResponse."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=None,
session=session,
agent_json=agent_json,
save=True,
)
assert isinstance(result, ErrorResponse)
assert "logged in" in result.message.lower()

View File

@@ -10,13 +10,34 @@ Lifecycle
---------
1. **Turn start** connect to the existing sandbox (sandbox_id in Redis) or
create a new one via ``get_or_create_sandbox()``.
``connect()`` in e2b v2 auto-resumes paused sandboxes.
2. **Execution** ``bash_exec`` and MCP file tools operate directly on the
sandbox's ``/home/user`` filesystem.
3. **Session expiry** E2B sandbox is killed by its own timeout (session_ttl).
3. **Turn end** the sandbox is paused via ``pause_sandbox()`` (fire-and-forget)
so idle time between turns costs nothing. Paused sandboxes have no compute
cost.
4. **Session delete** ``kill_sandbox()`` fully terminates the sandbox.
Cost control
------------
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
free.
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
is being provisioned, preventing duplicate creation under concurrent requests.
E2B project-level "paused sandbox lifetime" should be set to match
``_SANDBOX_ID_TTL`` (48 h) so orphaned paused sandboxes are auto-killed before
the Redis key expires.
"""
import asyncio
import contextlib
import logging
from typing import Any, Awaitable, Callable, Literal
from e2b import AsyncSandbox
@@ -24,147 +45,245 @@ from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
E2B_WORKDIR = "/home/user"
_CREATING = "__creating__"
_CREATION_LOCK_TTL = 60
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
_SANDBOX_KEY_PREFIX = "copilot:e2b:sandbox:"
_CREATING_SENTINEL = "creating"
# Short TTL for the "creating" sentinel — if the process dies mid-creation the
# lock auto-expires so other callers are not blocked forever.
_CREATION_LOCK_TTL = 60 # seconds
_MAX_WAIT_ATTEMPTS = 20 # 20 × 0.5 s = 10 s max wait
# Timeout for E2B API calls (pause/kill) — short because these are control-plane
# operations; if the sandbox is unreachable, fail fast and retry on the next turn.
_E2B_API_TIMEOUT_SECONDS = 10
# Redis TTL for the sandbox key. Must be ≥ the E2B project "paused sandbox
# lifetime" setting (recommended: set both to 48 h).
_SANDBOX_ID_TTL = 48 * 3600 # 48 hours
def _sandbox_key(session_id: str) -> str:
return f"{_SANDBOX_KEY_PREFIX}{session_id}"
async def _get_stored_sandbox_id(session_id: str) -> str | None:
redis = await get_redis_async()
raw = await redis.get(_sandbox_key(session_id))
value = raw.decode() if isinstance(raw, bytes) else raw
return None if value == _CREATING_SENTINEL else value
async def _set_stored_sandbox_id(session_id: str, sandbox_id: str) -> None:
redis = await get_redis_async()
await redis.set(_sandbox_key(session_id), sandbox_id, ex=_SANDBOX_ID_TTL)
async def _clear_stored_sandbox_id(session_id: str) -> None:
redis = await get_redis_async()
await redis.delete(_sandbox_key(session_id))
async def _try_reconnect(
sandbox_id: str, api_key: str, redis_key: str, timeout: int
sandbox_id: str, session_id: str, api_key: str
) -> "AsyncSandbox | None":
"""Try to reconnect to an existing sandbox. Returns None on failure."""
try:
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
if await sandbox.is_running():
redis = await get_redis_async()
await redis.expire(redis_key, timeout)
# Refresh TTL so an active session cannot lose its sandbox_id at expiry.
await _set_stored_sandbox_id(session_id, sandbox_id)
return sandbox
except Exception as exc:
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
# Stale — clear Redis so a new sandbox can be created.
redis = await get_redis_async()
await redis.delete(redis_key)
# Stale — clear the sandbox_id from Redis so a new one can be created.
await _clear_stored_sandbox_id(session_id)
return None
async def get_or_create_sandbox(
session_id: str,
api_key: str,
timeout: int,
template: str = "base",
timeout: int = 43200,
on_timeout: Literal["kill", "pause"] = "pause",
) -> AsyncSandbox:
"""Return the existing E2B sandbox for *session_id* or create a new one.
The sandbox_id is persisted in Redis so the same sandbox is reused
across turns. Concurrent calls for the same session are serialised
via a Redis ``SET NX`` creation lock.
The sandbox key in Redis serves a dual purpose: it stores the sandbox_id
and acts as a creation lock via a ``"creating"`` sentinel value. This
removes the need for a separate lock key.
*timeout* controls how long the e2b sandbox may run continuously before
the ``on_timeout`` lifecycle rule fires (default: 3 h).
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
or ``"kill"``.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
key = _sandbox_key(session_id)
# 1. Try reconnecting to an existing sandbox.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
for _ in range(_MAX_WAIT_ATTEMPTS):
raw = await redis.get(key)
value = raw.decode() if isinstance(raw, bytes) else raw
if value and value != _CREATING_SENTINEL:
# Existing sandbox ID — try to reconnect (auto-resumes if paused).
sandbox = await _try_reconnect(value, session_id, api_key)
if sandbox:
logger.info(
"[E2B] Reconnected to %.12s for session %.12s",
sandbox_id,
value,
session_id,
)
return sandbox
# _try_reconnect cleared the key — loop to create a new sandbox.
continue
# 2. Claim creation lock. If another request holds it, wait for the result.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
for _ in range(_MAX_WAIT_ATTEMPTS):
if value == _CREATING_SENTINEL:
# Another coroutine is creating — wait for it to finish.
await asyncio.sleep(0.5)
raw = await redis.get(redis_key)
if not raw:
break # Lock expired — fall through to retry creation
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
if sandbox:
return sandbox
break # Stale sandbox cleared — fall through to create
continue
# Try to claim creation lock again after waiting.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
# Another process may have created a sandbox — try to use it.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(
sandbox_id, api_key, redis_key, timeout
)
if sandbox:
return sandbox
raise RuntimeError(
f"Could not acquire E2B creation lock for session {session_id[:12]}"
)
# 3. Create a new sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template, api_key=api_key, timeout=timeout
# No sandbox and no active creation — atomically claim the creation slot.
claimed = await redis.set(
key, _CREATING_SENTINEL, nx=True, ex=_CREATION_LOCK_TTL
)
except Exception:
await redis.delete(redis_key)
raise
if not claimed:
# Race lost — another coroutine just claimed it.
await asyncio.sleep(0.1)
continue
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
# We hold the slot — create the sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template,
api_key=api_key,
timeout=timeout,
lifecycle={"on_timeout": on_timeout},
)
try:
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
except Exception:
# Redis save failed — kill the sandbox to avoid leaking it.
with contextlib.suppress(Exception):
await sandbox.kill()
raise
except Exception:
# Release the creation slot so other callers can proceed.
await redis.delete(key)
raise
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
raise RuntimeError(f"Could not acquire E2B sandbox for session {session_id[:12]}")
async def kill_sandbox(session_id: str, api_key: str) -> bool:
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
async def _act_on_sandbox(
session_id: str,
api_key: str,
action: str,
fn: Callable[[AsyncSandbox], Awaitable[Any]],
*,
clear_stored_id: bool = False,
) -> bool:
"""Connect to the sandbox for *session_id* and run *fn* on it.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
Shared by ``pause_sandbox`` and ``kill_sandbox``. Returns ``True`` on
success, ``False`` when no sandbox is found or the action fails.
If *clear_stored_id* is ``True``, the sandbox_id is removed from Redis
only after the action succeeds so a failed kill can be retried.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
raw = await redis.get(redis_key)
if not raw:
sandbox_id = await _get_stored_sandbox_id(session_id)
if not sandbox_id:
return False
sandbox_id = raw if isinstance(raw, str) else raw.decode()
await redis.delete(redis_key)
if sandbox_id == _CREATING:
return False
async def _run() -> None:
await fn(await AsyncSandbox.connect(sandbox_id, api_key=api_key))
try:
async def _connect_and_kill():
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
await sandbox.kill()
await asyncio.wait_for(_connect_and_kill(), timeout=10)
await asyncio.wait_for(_run(), timeout=_E2B_API_TIMEOUT_SECONDS)
if clear_stored_id:
await _clear_stored_sandbox_id(session_id)
logger.info(
"[E2B] Killed sandbox %.12s for session %.12s",
"[E2B] %s sandbox %.12s for session %.12s",
action.capitalize(),
sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
"[E2B] Failed to %s sandbox %.12s for session %.12s: %s",
action,
sandbox_id,
session_id,
exc,
)
return False
async def pause_sandbox(session_id: str, api_key: str) -> bool:
"""Pause the E2B sandbox for *session_id* to stop billing between turns.
Paused sandboxes cost nothing and are resumed automatically by
``get_or_create_sandbox()`` on the next turn (via ``AsyncSandbox.connect()``).
The sandbox_id is kept in Redis so reconnection works seamlessly.
Prefer ``pause_sandbox_direct()`` when the sandbox object is already in
scope — it skips the Redis lookup and reconnect round-trip.
Returns ``True`` if the sandbox was found and paused, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
return await _act_on_sandbox(session_id, api_key, "pause", lambda sb: sb.pause())
async def pause_sandbox_direct(sandbox: "AsyncSandbox", session_id: str) -> bool:
"""Pause an already-connected sandbox without a reconnect round-trip.
Use this in callers that already hold the live sandbox object (e.g. turn
teardown in ``service.py``). Saves the Redis lookup and
``AsyncSandbox.connect()`` call that ``pause_sandbox()`` would make.
Returns ``True`` on success, ``False`` on failure or timeout.
"""
try:
await asyncio.wait_for(sandbox.pause(), timeout=_E2B_API_TIMEOUT_SECONDS)
logger.info(
"[E2B] Paused sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to pause sandbox %.12s for session %.12s: %s",
sandbox.sandbox_id,
session_id,
exc,
)
return False
async def kill_sandbox(
session_id: str,
api_key: str,
) -> bool:
"""Kill the E2B sandbox for *session_id* and clear its Redis entry.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
return await _act_on_sandbox(
session_id,
api_key,
"kill",
lambda sb: sb.kill(),
clear_stored_id=True,
)

View File

@@ -1,6 +1,12 @@
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
Uses mock Redis and mock AsyncSandbox — no external dependencies.
sandbox_id is stored in Redis under _SANDBOX_KEY_PREFIX + session_id.
The same key doubles as a creation lock via a "creating" sentinel value.
Tests mock:
- ``get_redis_async`` (sandbox key storage + creation lock sentinel)
- ``AsyncSandbox`` (E2B SDK)
Tests are synchronous (using asyncio.run) to avoid conflicts with the
session-scoped event loop in conftest.py.
"""
@@ -11,36 +17,50 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .e2b_sandbox import (
_CREATING,
_SANDBOX_REDIS_PREFIX,
_CREATING_SENTINEL,
_try_reconnect,
get_or_create_sandbox,
kill_sandbox,
pause_sandbox,
pause_sandbox_direct,
)
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
_SESSION_ID = "sess-123"
_API_KEY = "test-api-key"
_SANDBOX_ID = "sb-abc"
_TIMEOUT = 300
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
def _mock_sandbox(sandbox_id: str = _SANDBOX_ID, running: bool = True) -> MagicMock:
sb = MagicMock()
sb.sandbox_id = sandbox_id
sb.is_running = AsyncMock(return_value=running)
sb.pause = AsyncMock()
sb.kill = AsyncMock()
return sb
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
def _mock_redis(
set_nx_result: bool = True,
stored_sandbox_id: str | None = None,
) -> AsyncMock:
"""Create a mock redis client.
*stored_sandbox_id* is returned by ``get()`` calls (simulates the sandbox_id
stored under the ``_SANDBOX_KEY_PREFIX`` key). ``set_nx_result`` controls
whether the creation-slot ``SET NX`` succeeds.
If *stored_sandbox_id* is None the key is absent (no sandbox, no lock).
"""
r = AsyncMock()
r.get = AsyncMock(return_value=get_val)
raw = stored_sandbox_id.encode() if stored_sandbox_id else None
r.get = AsyncMock(return_value=raw)
r.set = AsyncMock(return_value=set_nx_result)
r.setex = AsyncMock()
r.delete = AsyncMock()
r.expire = AsyncMock()
return r
def _patch_redis(redis):
def _patch_redis(redis: AsyncMock):
return patch(
"backend.copilot.tools.e2b_sandbox.get_redis_async",
new_callable=AsyncMock,
@@ -55,6 +75,7 @@ def _patch_redis(redis):
class TestTryReconnect:
def test_reconnect_success(self):
"""Returns the sandbox when it connects and is running; refreshes Redis TTL."""
sb = _mock_sandbox()
redis = _mock_redis()
with (
@@ -62,36 +83,39 @@ class TestTryReconnect:
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is sb
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
redis.delete.assert_not_awaited()
# TTL must be refreshed so an active session cannot lose its key at expiry.
redis.set.assert_awaited_once()
def test_reconnect_not_running_clears_key(self):
def test_reconnect_not_running_clears_redis(self):
"""Clears sandbox_id in Redis when the sandbox is no longer running."""
sb = _mock_sandbox(running=False)
redis = _mock_redis()
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
def test_reconnect_exception_clears_key(self):
redis = _mock_redis()
def test_reconnect_exception_clears_redis(self):
"""Clears sandbox_id in Redis when connect raises an exception."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
# ---------------------------------------------------------------------------
@@ -103,38 +127,63 @@ class TestGetOrCreateSandbox:
def test_reconnect_existing(self):
"""When Redis has a valid sandbox_id, reconnect to it."""
sb = _mock_sandbox()
redis = _mock_redis(get_val="sb-abc")
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
# redis.set called once to refresh TTL, not to claim a creation slot
redis.set.assert_awaited_once()
def test_create_new_when_no_key(self):
"""When Redis is empty, claim lock and create a new sandbox."""
sb = _mock_sandbox("sb-new")
redis = _mock_redis(get_val=None, set_nx_result=True)
def test_create_new_when_no_stored_id(self):
"""When Redis has no sandbox_id, claim slot and create a new sandbox."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
assert result is new_sb
mock_cls.create.assert_awaited_once()
# Verify lifecycle param is set
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
# sandbox_id should be saved to Redis
redis.set.assert_awaited()
def test_create_failure_clears_lock(self):
"""If sandbox creation fails, the Redis lock is deleted."""
redis = _mock_redis(get_val=None, set_nx_result=True)
def test_create_with_on_timeout_kill(self):
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=new_sb)
asyncio.run(
get_or_create_sandbox(
_SESSION_ID, _API_KEY, timeout=_TIMEOUT, on_timeout="kill"
)
)
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
def test_create_failure_releases_slot(self):
"""If sandbox creation fails, the Redis creation slot is deleted."""
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -142,17 +191,53 @@ class TestGetOrCreateSandbox:
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
with pytest.raises(RuntimeError, match="quota"):
asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
def test_wait_for_lock_then_reconnect(self):
"""When another process holds the lock, wait and reconnect."""
def test_redis_save_failure_kills_sandbox_and_releases_slot(self):
"""If Redis save fails after creation, sandbox is killed and slot released."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
# First set() call = creation slot SET NX (returns True).
# Second set() call = sandbox_id save (raises).
call_count = 0
async def _set_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return True # creation slot claimed
raise RuntimeError("redis error")
redis.set = AsyncMock(side_effect=_set_side_effect)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=new_sb)
with pytest.raises(RuntimeError, match="redis error"):
asyncio.run(
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
# Sandbox must be killed to avoid leaking it
new_sb.kill.assert_awaited_once()
# Creation slot must always be released
redis.delete.assert_awaited_once()
def test_wait_for_creating_sentinel_then_reconnect(self):
"""When the key holds the 'creating' sentinel, wait then reconnect."""
sb = _mock_sandbox("sb-other")
redis = _mock_redis()
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
# First get() returns the sentinel; second returns the real ID.
redis = AsyncMock()
creating_raw = _CREATING_SENTINEL.encode()
redis.get = AsyncMock(side_effect=[creating_raw, b"sb-other"])
redis.set = AsyncMock(return_value=False)
redis.delete = AsyncMock()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -163,16 +248,21 @@ class TestGetOrCreateSandbox:
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
def test_stale_reconnect_clears_and_creates(self):
"""When stored sandbox is stale, clear key and create a new one."""
"""When stored sandbox is stale (not running), clear it and create a new one."""
stale_sb = _mock_sandbox("sb-stale", running=False)
new_sb = _mock_sandbox("sb-fresh")
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
# First get() returns stale id (for reconnect check), then None (after clear).
redis = AsyncMock()
redis.get = AsyncMock(side_effect=[b"sb-stale", None])
redis.set = AsyncMock(return_value=True)
redis.delete = AsyncMock()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -180,10 +270,11 @@ class TestGetOrCreateSandbox:
mock_cls.connect = AsyncMock(return_value=stale_sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is new_sb
# Redis delete called at least once to clear stale id
redis.delete.assert_awaited()
@@ -194,70 +285,48 @@ class TestGetOrCreateSandbox:
class TestKillSandbox:
def test_kill_existing_sandbox(self):
"""Kill a running sandbox and clean up Redis."""
"""Kill a running sandbox and clear its Redis entry."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val="sb-abc")
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is True
redis.delete.assert_awaited_once_with(_KEY)
sb.kill.assert_awaited_once()
# Redis key cleared after successful kill
redis.delete.assert_awaited_once()
def test_kill_no_sandbox(self):
"""No-op when no sandbox exists in Redis."""
redis = _mock_redis(get_val=None)
"""No-op when Redis has no sandbox_id."""
redis = _mock_redis(stored_sandbox_id=None)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_state(self):
"""Clears Redis key but returns False when sandbox is still being created."""
redis = _mock_redis(get_val=_CREATING)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
def test_kill_connect_failure_keeps_redis(self):
"""Returns False and leaves Redis entry intact when connect/kill fails.
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
def test_kill_connect_failure(self):
"""Returns False and cleans Redis if connect/kill fails."""
redis = _mock_redis(get_val="sb-abc")
Keeping the sandbox_id in Redis allows the kill to be retried.
"""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_not_awaited()
def test_kill_with_bytes_redis_value(self):
"""Redis may return bytes — kill_sandbox should decode correctly."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val=b"sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is True
sb.kill.assert_awaited_once()
def test_kill_timeout_returns_false(self):
"""Returns False when E2B API calls exceed the 10s timeout."""
redis = _mock_redis(get_val="sb-abc")
def test_kill_timeout_keeps_redis(self):
"""Returns False and leaves Redis entry intact when the E2B call times out."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
_patch_redis(redis),
patch(
@@ -266,7 +335,146 @@ class TestKillSandbox:
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_sentinel_returns_false(self):
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
# ---------------------------------------------------------------------------
# pause_sandbox
# ---------------------------------------------------------------------------
class TestPauseSandbox:
def test_pause_existing_sandbox(self):
"""Pause a running sandbox; Redis sandbox_id is preserved."""
sb = _mock_sandbox()
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is True
sb.pause.assert_awaited_once()
# sandbox_id should remain in Redis (not cleared on pause)
redis.delete.assert_not_awaited()
def test_pause_no_sandbox(self):
"""No-op when Redis has no sandbox_id."""
redis = _mock_redis(stored_sandbox_id=None)
with _patch_redis(redis):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_connect_failure(self):
"""Returns False if connect fails."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_creating_sentinel_returns_false(self):
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
with _patch_redis(redis):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_timeout_returns_false(self):
"""Returns False and preserves Redis entry when the E2B API call times out."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
_patch_redis(redis),
patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
# sandbox_id must remain in Redis so the next turn can reconnect
redis.delete.assert_not_awaited()
def test_pause_then_reconnect_reuses_sandbox(self):
"""After pause, get_or_create_sandbox reconnects the same sandbox.
Covers the pause->reconnect cycle: connect() auto-resumes a paused
sandbox, and is_running() returns True once resume completes, so the
same sandbox_id is reused rather than a new one being created.
"""
sb = _mock_sandbox(_SANDBOX_ID)
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
# Step 1: pause the sandbox
paused = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert paused is True
sb.pause.assert_awaited_once()
# Step 2: reconnect on next turn -- same sandbox should be returned
result = asyncio.run(
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
# ---------------------------------------------------------------------------
# pause_sandbox_direct
# ---------------------------------------------------------------------------
class TestPauseSandboxDirect:
def test_pause_direct_success(self):
"""Pauses the sandbox directly without a Redis lookup or reconnect."""
sb = _mock_sandbox()
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is True
sb.pause.assert_awaited_once()
def test_pause_direct_failure_returns_false(self):
"""Returns False when sandbox.pause() raises."""
sb = _mock_sandbox()
sb.pause = AsyncMock(side_effect=RuntimeError("e2b error"))
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is False
def test_pause_direct_timeout_returns_false(self):
"""Returns False when sandbox.pause() exceeds the 10s timeout."""
sb = _mock_sandbox()
with patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
):
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)

View File

@@ -1,32 +1,20 @@
"""EditAgentTool - Edits existing agents using natural language."""
"""EditAgentTool - Edits existing agents using pre-built JSON."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_user_message_for_error,
save_agent_to_library,
)
from .agent_generator import get_agent_as_json
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
"""Tool for editing existing agents using pre-built JSON."""
@property
def name(self) -> str:
@@ -35,11 +23,12 @@ class EditAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"Edit an existing agent. Pass `agent_json` with the complete "
"updated agent JSON you generated. The tool validates, auto-fixes, "
"and saves.\n\n"
"IMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
"that could be used as building blocks."
)
@property
@@ -58,26 +47,20 @@ class EditAgentTool(BaseTool):
"Can be a graph ID or library agent ID."
),
},
"changes": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"Natural language description of what changes to make. "
"Be specific about what to add, remove, or modify."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
"Complete updated agent JSON to validate and save. "
"Must contain 'nodes' and 'links'. "
"Include 'name' and/or 'description' if they need "
"to be updated."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
"List of library agent IDs to use as building blocks for the changes."
),
},
"save": {
@@ -89,7 +72,7 @@ class EditAgentTool(BaseTool):
"default": True,
},
},
"required": ["agent_id", "changes"],
"required": ["agent_id", "agent_json"],
}
async def _execute(
@@ -98,36 +81,39 @@ class EditAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the edit_agent tool.
Flow:
1. Fetch the current agent
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
error="Missing agent_id parameter",
error="missing_agent_id",
session_id=session_id,
)
if not changes:
if not agent_json:
return ErrorResponse(
message="Please describe what changes you want to make.",
error="Missing changes parameter",
message=(
"Please provide agent_json with the complete updated agent graph."
),
error="missing_agent_json",
session_id=session_id,
)
current_agent = await get_agent_as_json(agent_id, user_id)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes.",
error="empty_agent",
session_id=session_id,
)
# Preserve original agent's ID
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
return ErrorResponse(
message=f"Could not find agent with ID '{agent_id}' in your library.",
@@ -135,142 +121,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
agent_json["id"] = current_agent.get("id", agent_id)
agent_json["version"] = current_agent.get("version", 1)
agent_json.setdefault("is_active", True)
graph_id = current_agent.get("id")
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
library_agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=filtered_ids,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
try:
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent editing is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed",
details={"agent_id": agent_id, "changes": changes[:100]},
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
agent_json=updated_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
updated_agent, user_id, is_update=True
)
return AgentSavedResponse(
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to save the updated agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=True,
default_name="Updated Agent",
library_agents=library_agents,
)

View File

@@ -32,7 +32,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with discovery + auth flow
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
@@ -72,6 +72,15 @@ class FindBlockTool(BaseTool):
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
},
"include_schemas": {
"type": "boolean",
"description": (
"If true, include full input_schema and output_schema "
"for each block. Use when generating agent JSON that "
"needs block schemas. Default is false."
),
"default": False,
},
},
"required": ["query"],
}
@@ -99,6 +108,7 @@ class FindBlockTool(BaseTool):
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
include_schemas = kwargs.get("include_schemas", False)
session_id = session.session_id
if not query:
@@ -143,15 +153,21 @@ class FindBlockTool(BaseTool):
):
continue
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=[c.value for c in block.categories],
)
summary = BlockInfoSummary(
id=block_id,
name=block.name,
description=block.optimized_description or block.description or "",
categories=[c.value for c in block.categories],
)
if include_schemas:
info = block.get_info()
summary.input_schema = info.inputSchema
summary.output_schema = info.outputSchema
summary.static_output = info.staticOutput
blocks.append(summary)
if len(blocks) >= _TARGET_RESULTS:
break

View File

@@ -25,6 +25,7 @@ def make_mock_block(
input_schema: dict | None = None,
output_schema: dict | None = None,
credentials_fields: dict | None = None,
static_output: bool = False,
):
"""Create a mock block for testing."""
mock = MagicMock()
@@ -33,6 +34,7 @@ def make_mock_block(
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.static_output = static_output
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = input_schema or {
"properties": {},
@@ -42,6 +44,15 @@ def make_mock_block(
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = output_schema or {}
mock.categories = []
mock.optimized_description = None
# Mock get_info() for include_schemas support
mock_info = MagicMock()
mock_info.inputSchema = input_schema or {"properties": {}, "required": []}
mock_info.outputSchema = output_schema or {}
mock_info.staticOutput = static_output
mock.get_info.return_value = mock_info
return mock
@@ -399,3 +410,92 @@ class TestFindBlockFiltering:
f"Average chars per block ({avg_chars}) exceeds 500. "
f"Total response: {total_chars} chars for {response.count} blocks."
)
@pytest.mark.asyncio(loop_scope="session")
async def test_include_schemas_false_omits_schemas(self):
"""Without include_schemas, schemas should be empty dicts."""
session = make_session(user_id=_TEST_USER_ID)
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
output_schema = {"properties": {"result": {"type": "string"}}}
search_results = [{"content_id": "block-1", "score": 0.9}]
block = make_mock_block(
"block-1",
"Test Block",
BlockType.STANDARD,
input_schema=input_schema,
output_schema=output_schema,
)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 1)
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
),
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="test",
include_schemas=False,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].input_schema == {}
assert response.blocks[0].output_schema == {}
assert response.blocks[0].static_output is False
@pytest.mark.asyncio(loop_scope="session")
async def test_include_schemas_true_populates_schemas(self):
"""With include_schemas=true, schemas should be populated from block info."""
session = make_session(user_id=_TEST_USER_ID)
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
output_schema = {"properties": {"result": {"type": "string"}}}
search_results = [{"content_id": "block-1", "score": 0.9}]
block = make_mock_block(
"block-1",
"Test Block",
BlockType.STANDARD,
input_schema=input_schema,
output_schema=output_schema,
static_output=True,
)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 1)
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
),
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="test",
include_schemas=True,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].input_schema == input_schema
assert response.blocks[0].output_schema == output_schema
assert response.blocks[0].static_output is True

View File

@@ -22,6 +22,9 @@ class FindLibraryAgentTool(BaseTool):
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"When creating agents with sub-agent composition, use this to get "
"the agent's graph_id, graph_version, input_schema, and output_schema "
"needed for AgentExecutorBlock nodes. "
"Omit the query to list all agents."
)

View File

@@ -0,0 +1,134 @@
"""FixAgentGraphTool - Auto-fixes common agent JSON issues."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator.validation import AgentFixer, AgentValidator, get_blocks_as_dicts
from .base import BaseTool
from .models import ErrorResponse, FixResultResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class FixAgentGraphTool(BaseTool):
"""Tool for auto-fixing common issues in agent JSON graphs."""
@property
def name(self) -> str:
return "fix_agent_graph"
@property
def description(self) -> str:
return (
"Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
"- Missing or invalid UUIDs on nodes and links\n"
"- StoreValueBlock prerequisites for ConditionBlock\n"
"- Double curly brace escaping in prompt templates\n"
"- AddToList/AddToDictionary prerequisite blocks\n"
"- CodeExecutionBlock output field naming\n"
"- Missing credentials configuration\n"
"- Node X coordinate spacing (800+ units apart)\n"
"- AI model default parameters\n"
"- Link static properties based on input schema\n"
"- Type mismatches (inserts conversion blocks)\n\n"
"Returns the fixed agent JSON plus a list of fixes applied. "
"After fixing, the agent is re-validated. If still invalid, "
"the remaining errors are included in the response."
)
@property
def requires_auth(self) -> bool:
return False
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_json": {
"type": "object",
"description": (
"The agent JSON to fix. Must contain 'nodes' and 'links' arrays."
),
},
},
"required": ["agent_json"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
agent_json = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json or not isinstance(agent_json, dict):
return ErrorResponse(
message="Please provide a valid agent JSON object.",
error="Missing or invalid agent_json parameter",
session_id=session_id,
)
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes. An agent needs at least one block.",
error="empty_agent",
session_id=session_id,
)
try:
blocks = get_blocks_as_dicts()
fixer = AgentFixer()
fixed_agent = fixer.apply_all_fixes(agent_json, blocks)
fixes_applied = fixer.get_fixes_applied()
except Exception as e:
logger.error(f"Fixer error: {e}", exc_info=True)
return ErrorResponse(
message=f"Auto-fix encountered an error: {str(e)}",
error="fix_exception",
session_id=session_id,
)
# Re-validate after fixing
try:
validator = AgentValidator()
is_valid, _ = validator.validate(fixed_agent, blocks)
remaining_errors = validator.errors if not is_valid else []
except Exception as e:
logger.warning(f"Post-fix validation error: {e}", exc_info=True)
remaining_errors = [f"Post-fix validation failed: {str(e)}"]
is_valid = False
if is_valid:
return FixResultResponse(
message=(
f"Applied {len(fixes_applied)} fix(es). "
"Agent graph is now valid!"
),
fixed_agent_json=fixed_agent,
fixes_applied=fixes_applied,
fix_count=len(fixes_applied),
valid_after_fix=True,
remaining_errors=[],
session_id=session_id,
)
return FixResultResponse(
message=(
f"Applied {len(fixes_applied)} fix(es), but "
f"{len(remaining_errors)} issue(s) remain. "
"Review the remaining errors and fix manually."
),
fixed_agent_json=fixed_agent,
fixes_applied=fixes_applied,
fix_count=len(fixes_applied),
valid_after_fix=False,
remaining_errors=remaining_errors,
session_id=session_id,
)

View File

@@ -0,0 +1,189 @@
"""Tests for FixAgentGraphTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.fix_agent import FixAgentGraphTool
from backend.copilot.tools.models import ErrorResponse, FixResultResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-fix-agent"
@pytest.fixture
def tool():
return FixAgentGraphTool()
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
@pytest.mark.asyncio
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "agent_json" in result.error.lower()
@pytest.mark.asyncio
async def test_empty_nodes_returns_error(tool, session):
"""Agent JSON with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_fix_and_validate_success(tool, session):
"""Fixer applies fixes and validator passes -> valid_after_fix=True."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
fixed_agent = dict(agent_json) # Fixer returns the full agent dict
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
mock_fixer.get_fixes_applied.return_value = ["Fixed node UUID format"]
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.fix_agent.AgentFixer",
return_value=mock_fixer,
),
patch(
"backend.copilot.tools.fix_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, FixResultResponse)
assert result.valid_after_fix is True
assert result.fix_count == 1
assert result.fixes_applied == ["Fixed node UUID format"]
assert result.remaining_errors == []
@pytest.mark.asyncio
async def test_fix_with_remaining_errors(tool, session):
"""Fixer applies some fixes but validation still fails."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [
{
"id": "link-1",
"source_id": "node-1",
"source_name": "output",
"sink_id": "node-2",
"sink_name": "input",
}
],
}
fixed_agent = dict(agent_json)
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
mock_fixer.get_fixes_applied.return_value = ["Fixed UUID"]
mock_validator = MagicMock()
mock_validator.validate.return_value = (
False,
"Link references non-existent node 'node-2'",
)
mock_validator.errors = ["Link references non-existent node 'node-2'"]
with (
patch(
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.fix_agent.AgentFixer",
return_value=mock_fixer,
),
patch(
"backend.copilot.tools.fix_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, FixResultResponse)
assert result.valid_after_fix is False
assert result.fix_count == 1
assert len(result.remaining_errors) == 1
@pytest.mark.asyncio
async def test_fixer_exception_returns_error(tool, session):
"""Fixer exception returns ErrorResponse."""
agent_json = {
"nodes": [{"id": "n1", "block_id": "b1", "input_default": {}, "metadata": {}}],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(side_effect=RuntimeError("fixer crashed"))
with (
patch(
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.fix_agent.AgentFixer",
return_value=mock_fixer,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "fix_exception" in result.error

View File

@@ -0,0 +1,84 @@
"""GetAgentBuildingGuideTool - Returns the complete agent building guide."""
import logging
from pathlib import Path
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
_GUIDE_CACHE: str | None = None
def _load_guide() -> str:
global _GUIDE_CACHE
if _GUIDE_CACHE is None:
guide_path = Path(__file__).parent.parent / "sdk" / "agent_generation_guide.md"
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
return _GUIDE_CACHE
class AgentBuildingGuideResponse(ToolResponseBase):
"""Response containing the agent building guide."""
type: ResponseType = ResponseType.AGENT_BUILDER_GUIDE
content: str
class GetAgentBuildingGuideTool(BaseTool):
"""Returns the complete guide for building agent JSON graphs.
Covers block IDs, link structure, AgentInputBlock, AgentOutputBlock,
AgentExecutorBlock (sub-agent composition), and MCPToolBlock usage.
"""
@property
def name(self) -> str:
return "get_agent_building_guide"
@property
def description(self) -> str:
return (
"Returns the complete guide for building agent JSON graphs, including "
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
"Call this before generating agent JSON to ensure correct structure."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:
content = _load_guide()
return AgentBuildingGuideResponse(
message="Agent building guide loaded.",
content=content,
session_id=session_id,
)
except Exception as e:
logger.error("Failed to load agent building guide: %s", e)
return ErrorResponse(
message="Failed to load agent building guide.",
error=str(e),
session_id=session_id,
)

View File

@@ -0,0 +1,79 @@
"""GetMCPGuideTool - Returns the MCP tool usage guide."""
import logging
from pathlib import Path
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
_GUIDE_CACHE: str | None = None
def _load_guide() -> str:
global _GUIDE_CACHE
if _GUIDE_CACHE is None:
guide_path = Path(__file__).parent.parent / "sdk" / "mcp_tool_guide.md"
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
return _GUIDE_CACHE
class MCPGuideResponse(ToolResponseBase):
"""Response containing the MCP tool guide."""
type: ResponseType = ResponseType.MCP_GUIDE
content: str
class GetMCPGuideTool(BaseTool):
"""Returns the MCP tool usage guide with known server URLs and auth details."""
@property
def name(self) -> str:
return "get_mcp_guide"
@property
def description(self) -> str:
return (
"Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
"Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
"Call before using run_mcp_tool if you need a server URL or auth info."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:
content = _load_guide()
return MCPGuideResponse(
message="MCP guide loaded.",
content=content,
session_id=session_id,
)
except Exception as e:
logger.error("Failed to load MCP guide: %s", e)
return ErrorResponse(
message="Failed to load MCP guide.",
error=str(e),
session_id=session_id,
)

View File

@@ -12,50 +12,51 @@ from backend.data.model import CredentialsMetaInput
class ResponseType(str, Enum):
"""Types of tool responses."""
# General
ERROR = "error"
NO_RESULTS = "no_results"
NEED_LOGIN = "need_login"
# Agent discovery & execution
AGENTS_FOUND = "agents_found"
AGENT_DETAILS = "agent_details"
SETUP_REQUIREMENTS = "setup_requirements"
INPUT_VALIDATION_ERROR = "input_validation_error"
EXECUTION_STARTED = "execution_started"
NEED_LOGIN = "need_login"
ERROR = "error"
NO_RESULTS = "no_results"
AGENT_OUTPUT = "agent_output"
UNDERSTANDING_UPDATED = "understanding_updated"
AGENT_PREVIEW = "agent_preview"
AGENT_SAVED = "agent_saved"
CLARIFICATION_NEEDED = "clarification_needed"
SUGGESTED_GOAL = "suggested_goal"
# Agent builder (create / edit / validate / fix)
AGENT_BUILDER_GUIDE = "agent_builder_guide"
AGENT_BUILDER_PREVIEW = "agent_builder_preview"
AGENT_BUILDER_SAVED = "agent_builder_saved"
AGENT_BUILDER_CLARIFICATION_NEEDED = "agent_builder_clarification_needed"
AGENT_BUILDER_VALIDATION_RESULT = "agent_builder_validation_result"
AGENT_BUILDER_FIX_RESULT = "agent_builder_fix_result"
# Block
BLOCK_LIST = "block_list"
BLOCK_DETAILS = "block_details"
BLOCK_OUTPUT = "block_output"
# MCP
MCP_GUIDE = "mcp_guide"
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
MCP_TOOL_OUTPUT = "mcp_tool_output"
# Docs
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
# Workspace files
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Agent-browser multi-step automation (navigate, act, screenshot)
BROWSER_NAVIGATE = "browser_navigate"
BROWSER_ACT = "browser_act"
BROWSER_SCREENSHOT = "browser_screenshot"
# Code execution
BASH_EXEC = "bash_exec"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Goal refinement
SUGGESTED_GOAL = "suggested_goal"
# MCP tool types
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
MCP_TOOL_OUTPUT = "mcp_tool_output"
# Folder management types
# Folder management
FOLDER_CREATED = "folder_created"
FOLDER_LIST = "folder_list"
FOLDER_UPDATED = "folder_updated"
@@ -63,6 +64,21 @@ class ResponseType(str, Enum):
FOLDER_DELETED = "folder_deleted"
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
# Browser automation
BROWSER_NAVIGATE = "browser_navigate"
BROWSER_ACT = "browser_act"
BROWSER_SCREENSHOT = "browser_screenshot"
# Code execution
BASH_EXEC = "bash_exec"
# Web
WEB_FETCH = "web_fetch"
# Feature requests
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Base response model
class ToolResponseBase(BaseModel):
@@ -92,6 +108,15 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
graph_version: int | None = None
input_schema: dict[str, Any] | None = Field(
default=None,
description="JSON Schema for the agent's inputs (for AgentExecutorBlock)",
)
output_schema: dict[str, Any] | None = Field(
default=None,
description="JSON Schema for the agent's outputs (for AgentExecutorBlock)",
)
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
@@ -282,7 +307,7 @@ class ClarifyingQuestion(BaseModel):
class AgentPreviewResponse(ToolResponseBase):
"""Response for previewing a generated agent before saving."""
type: ResponseType = ResponseType.AGENT_PREVIEW
type: ResponseType = ResponseType.AGENT_BUILDER_PREVIEW
agent_json: dict[str, Any]
agent_name: str
description: str
@@ -293,7 +318,7 @@ class AgentPreviewResponse(ToolResponseBase):
class AgentSavedResponse(ToolResponseBase):
"""Response when an agent is saved to the library."""
type: ResponseType = ResponseType.AGENT_SAVED
type: ResponseType = ResponseType.AGENT_BUILDER_SAVED
agent_id: str
agent_name: str
library_agent_id: str
@@ -304,7 +329,7 @@ class AgentSavedResponse(ToolResponseBase):
class ClarificationNeededResponse(ToolResponseBase):
"""Response when the LLM needs more information from the user."""
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
type: ResponseType = ResponseType.AGENT_BUILDER_CLARIFICATION_NEEDED
questions: list[ClarifyingQuestion] = Field(default_factory=list)
@@ -381,6 +406,10 @@ class BlockInfoSummary(BaseModel):
default_factory=dict,
description="Full JSON schema for block outputs",
)
static_output: bool = Field(
default=False,
description="Whether the block produces output without needing input",
)
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
@@ -429,18 +458,6 @@ class BlockOutputResponse(ToolResponseBase):
success: bool = True
# Long-running operation models
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
Returned for idempotency when the same tool_call_id is requested again
while the background task is still running.
"""
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""
@@ -548,6 +565,29 @@ class BrowserScreenshotResponse(ToolResponseBase):
filename: str
# Agent generation tool response models
class ValidationResultResponse(ToolResponseBase):
"""Response for validate_agent_graph tool."""
type: ResponseType = ResponseType.AGENT_BUILDER_VALIDATION_RESULT
valid: bool
errors: list[str] = Field(default_factory=list)
error_count: int = 0
class FixResultResponse(ToolResponseBase):
"""Response for fix_agent_graph tool."""
type: ResponseType = ResponseType.AGENT_BUILDER_FIX_RESULT
fixed_agent_json: dict[str, Any]
fixes_applied: list[str] = Field(default_factory=list)
fix_count: int = 0
valid_after_fix: bool = False
remaining_errors: list[str] = Field(default_factory=list)
# Folder management models

View File

@@ -7,7 +7,7 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.blocks import get_block
from backend.blocks import BlockType, get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
@@ -83,7 +83,7 @@ class RunBlockTool(BaseTool):
),
},
},
"required": ["block_id", "input_data"],
"required": ["block_id", "block_name", "input_data"],
}
@property
@@ -149,11 +149,18 @@ class RunBlockTool(BaseTool):
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
# Provide actionable guidance for blocks with dedicated tools
if block.block_type == BlockType.MCP_TOOL:
hint = (
" Use the `run_mcp_tool` tool instead — it handles "
"MCP server discovery, authentication, and execution."
)
elif block.block_type == BlockType.AGENT:
hint = " Use the `run_agent` tool instead."
else:
hint = " This block is designed for use within graphs only."
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
message=f"Block '{block.name}' cannot be run directly.{hint}",
session_id=session_id,
)

View File

@@ -89,7 +89,7 @@ class TestRunBlockFiltering:
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "cannot be run directly" in response.message
assert "designed for use within graphs only" in response.message
@pytest.mark.asyncio(loop_scope="session")
@@ -115,7 +115,7 @@ class TestRunBlockFiltering:
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "cannot be run directly" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
@@ -141,7 +141,7 @@ class TestRunBlockFiltering:
# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message
assert "cannot be run directly" not in response.message
class TestRunBlockInputValidation:

View File

@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
from backend.util.request import HTTPClientError, validate_url
from backend.util.request import HTTPClientError, validate_url_host
from .base import BaseTool
from .models import (
@@ -53,15 +53,9 @@ class RunMCPToolTool(BaseTool):
def description(self) -> str:
return (
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
"Once connected and user confirms, retry the same call immediately."
"Two-step: (1) call with server_url to list available tools, "
"(2) call again with server_url + tool_name + tool_arguments to execute. "
"Call get_mcp_guide for known server URLs and auth details."
)
@property
@@ -150,7 +144,7 @@ class RunMCPToolTool(BaseTool):
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
try:
await validate_url(server_url, trusted_origins=[])
await validate_url_host(server_url)
except ValueError as e:
msg = str(e)
if "Unable to resolve" in msg or "No IP addresses" in msg:

View File

@@ -100,7 +100,7 @@ async def test_ssrf_blocked_url_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
"backend.copilot.tools.run_mcp_tool.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -138,7 +138,7 @@ async def test_non_dict_tool_arguments_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
"backend.copilot.tools.run_mcp_tool.validate_url_host",
new_callable=AsyncMock,
):
with patch(
@@ -171,7 +171,7 @@ async def test_discover_tools_returns_discovered_response():
mock_tools = _make_tool_list("fetch", "search")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -208,7 +208,7 @@ async def test_discover_tools_with_credentials():
mock_tools = _make_tool_list("push_notification")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -249,7 +249,7 @@ async def test_execute_tool_returns_output_response():
text_result = "# Example Domain\nThis domain is for examples."
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -285,7 +285,7 @@ async def test_execute_tool_parses_json_result():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -320,7 +320,7 @@ async def test_execute_tool_image_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -359,7 +359,7 @@ async def test_execute_tool_resource_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -399,7 +399,7 @@ async def test_execute_tool_multi_item_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -437,7 +437,7 @@ async def test_execute_tool_empty_content_returns_none():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -470,7 +470,7 @@ async def test_execute_tool_returns_error_on_tool_failure():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -512,7 +512,7 @@ async def test_auth_required_without_creds_returns_setup_requirements():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -555,7 +555,7 @@ async def test_auth_error_with_existing_creds_returns_error():
mock_creds.access_token = SecretStr("stale-token")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -589,7 +589,7 @@ async def test_mcp_client_error_returns_error_response():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -621,7 +621,7 @@ async def test_unexpected_exception_returns_generic_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -719,7 +719,7 @@ async def test_credential_lookup_normalizes_trailing_slash():
url_with_slash = "https://mcp.example.com/mcp/"
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",

View File

@@ -0,0 +1,116 @@
"""ValidateAgentGraphTool - Validates agent JSON structure."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator.validation import AgentValidator, get_blocks_as_dicts
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, ValidationResultResponse
logger = logging.getLogger(__name__)
class ValidateAgentGraphTool(BaseTool):
"""Tool for validating agent JSON graphs."""
@property
def name(self) -> str:
return "validate_agent_graph"
@property
def description(self) -> str:
return (
"Validate an agent JSON graph for correctness. Checks:\n"
"- All block_ids reference real blocks\n"
"- All links reference valid source/sink nodes and fields\n"
"- Required input fields are wired or have defaults\n"
"- Data types are compatible across links\n"
"- Nested sink links use correct notation\n"
"- Prompt templates use proper curly brace escaping\n"
"- AgentExecutorBlock configurations are valid\n\n"
"Call this after generating agent JSON to verify correctness. "
"If validation fails, either fix issues manually based on the error "
"descriptions, or call fix_agent_graph to auto-fix common problems."
)
@property
def requires_auth(self) -> bool:
return False
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_json": {
"type": "object",
"description": (
"The agent JSON to validate. Must contain 'nodes' and 'links' arrays. "
"Each node needs: id (UUID), block_id, input_default, metadata. "
"Each link needs: id (UUID), source_id, source_name, sink_id, sink_name."
),
},
},
"required": ["agent_json"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
agent_json = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json or not isinstance(agent_json, dict):
return ErrorResponse(
message="Please provide a valid agent JSON object.",
error="Missing or invalid agent_json parameter",
session_id=session_id,
)
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes. An agent needs at least one block.",
error="empty_agent",
session_id=session_id,
)
try:
blocks = get_blocks_as_dicts()
validator = AgentValidator()
is_valid, error_message = validator.validate(agent_json, blocks)
except Exception as e:
logger.error(f"Validation error: {e}", exc_info=True)
return ErrorResponse(
message=f"Validation encountered an error: {str(e)}",
error="validation_exception",
session_id=session_id,
)
if is_valid:
return ValidationResultResponse(
message="Agent graph is valid! No issues found.",
valid=True,
errors=[],
error_count=0,
session_id=session_id,
)
# Parse individual errors from the validator's error list
errors = validator.errors if hasattr(validator, "errors") else []
if not errors and error_message:
errors = [error_message]
return ValidationResultResponse(
message=f"Found {len(errors)} validation error(s). Fix them and re-validate.",
valid=False,
errors=errors,
error_count=len(errors),
session_id=session_id,
)

View File

@@ -0,0 +1,160 @@
"""Tests for ValidateAgentGraphTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.models import ErrorResponse, ValidationResultResponse
from backend.copilot.tools.validate_agent import ValidateAgentGraphTool
from ._test_data import make_session
_TEST_USER_ID = "test-user-validate-agent"
@pytest.fixture
def tool():
return ValidateAgentGraphTool()
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
@pytest.mark.asyncio
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "agent_json" in result.error.lower()
@pytest.mark.asyncio
async def test_empty_nodes_returns_error(tool, session):
"""Agent JSON with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_valid_agent_returns_success(tool, session):
"""Valid agent returns ValidationResultResponse with valid=True."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.validate_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ValidationResultResponse)
assert result.valid is True
assert result.error_count == 0
assert result.errors == []
@pytest.mark.asyncio
async def test_invalid_agent_returns_errors(tool, session):
"""Invalid agent returns ValidationResultResponse with errors."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "nonexistent-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Validation failed")
mock_validator.errors = [
"Block 'nonexistent-block' not found in registry",
"Missing required input field 'prompt'",
]
with (
patch(
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.validate_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ValidationResultResponse)
assert result.valid is False
assert result.error_count == 2
assert len(result.errors) == 2
@pytest.mark.asyncio
async def test_validation_exception_returns_error(tool, session):
"""Validator exception returns ErrorResponse."""
agent_json = {
"nodes": [{"id": "n1", "block_id": "b1", "input_default": {}, "metadata": {}}],
"links": [],
}
mock_validator = MagicMock()
mock_validator.validate.side_effect = RuntimeError("unexpected error")
with (
patch(
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.validate_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "validation_exception" in result.error

View File

@@ -7,8 +7,12 @@ from typing import Any, Optional
from pydantic import BaseModel
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.tools.sandbox import make_session_path
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
@@ -83,13 +87,11 @@ def _resolve_sandbox_path(
) -> str | ErrorResponse:
"""Normalize *path* to an absolute sandbox path under :data:`E2B_WORKDIR`.
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools._resolve_remote`
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools.resolve_sandbox_path`
and wraps any ``ValueError`` into an :class:`ErrorResponse`.
"""
from backend.copilot.sdk.e2b_file_tools import _resolve_remote
try:
return _resolve_remote(path)
return resolve_sandbox_path(path)
except ValueError:
return ErrorResponse(
message=f"{param_name} must be within {E2B_WORKDIR}",
@@ -99,7 +101,6 @@ def _resolve_sandbox_path(
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
"""Read *source_path* from E2B sandbox or local ephemeral directory."""
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
@@ -143,7 +144,6 @@ async def _save_to_path(
Returns the resolved path on success, or an ``ErrorResponse`` on failure.
"""
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:

View File

@@ -37,6 +37,7 @@ async def initialize_blocks() -> None:
name=block.name,
inputSchema=json.dumps(block.input_schema.jsonschema()),
outputSchema=json.dumps(block.output_schema.jsonschema()),
description=block.description,
)
)
return
@@ -48,6 +49,7 @@ async def initialize_blocks() -> None:
or block.name != existing_block.name
or input_schema != existing_block.inputSchema
or output_schema != existing_block.outputSchema
or block.description != existing_block.description
):
await AgentBlock.prisma().update(
where={"id": existing_block.id},
@@ -56,6 +58,8 @@ async def initialize_blocks() -> None:
"name": block.name,
"inputSchema": input_schema,
"outputSchema": output_schema,
"description": block.description,
"optimizedDescription": None,
},
)
@@ -77,3 +81,45 @@ async def initialize_blocks() -> None:
f"Failed to sync {len(failed_blocks)} block(s) to database: "
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
)
# Load optimized descriptions from DB onto block classes so that
# every get_block() instance automatically carries them.
try:
all_db_blocks = await AgentBlock.prisma().find_many(
where={"optimizedDescription": {"not": None}},
)
block_classes = get_blocks()
applied = 0
for db_block in all_db_blocks:
if db_block.optimizedDescription and db_block.id in block_classes:
block_classes[db_block.id]._optimized_description = (
db_block.optimizedDescription
)
applied += 1
if applied:
logger.info("Loaded %d optimized block descriptions", applied)
except Exception:
logger.error("Could not load optimized descriptions", exc_info=True)
async def get_blocks_needing_optimization() -> list[dict[str, str]]:
"""Return blocks that have a description but no optimized description yet."""
blocks = await AgentBlock.prisma().find_many(
where={
"description": {"not": None},
"optimizedDescription": None,
},
)
return [
{"id": b.id, "name": b.name, "description": b.description or ""} for b in blocks
]
async def update_block_optimized_description(
block_id: str, optimized_description: str
) -> None:
"""Store an LLM-optimized description for a block."""
await AgentBlock.prisma().update(
where={"id": block_id},
data={"optimizedDescription": optimized_description},
)

View File

@@ -39,6 +39,10 @@ from backend.data.analytics import (
get_marketplace_graphs_for_monitoring,
)
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import (
get_blocks_needing_optimization,
update_block_optimized_description,
)
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
@@ -313,6 +317,10 @@ class DatabaseManager(AppService):
get_business_understanding = _(get_business_understanding)
upsert_business_understanding = _(upsert_business_understanding)
# ============ Block Descriptions ============ #
get_blocks_needing_optimization = _(get_blocks_needing_optimization)
update_block_optimized_description = _(update_block_optimized_description)
# ============ CoPilot Chat Sessions ============ #
get_chat_session = _(chat_db.get_chat_session)
create_chat_session = _(chat_db.create_chat_session)
@@ -379,6 +387,10 @@ class DatabaseManagerClient(AppServiceClient):
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)
# Block Descriptions
get_blocks_needing_optimization = _(d.get_blocks_needing_optimization)
update_block_optimized_description = _(d.update_block_optimized_description)
class DatabaseManagerAsyncClient(AppServiceClient):
d = DatabaseManager
@@ -495,6 +507,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding
# ============ Block Descriptions ============ #
get_blocks_needing_optimization = d.get_blocks_needing_optimization
# ============ CoPilot Chat Sessions ============ #
get_chat_session = d.get_chat_session
create_chat_session = d.create_chat_session

View File

@@ -1,601 +0,0 @@
import asyncio
import csv
import io
import logging
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Literal, Optional
from uuid import uuid4
import prisma.enums
import prisma.models
import prisma.types
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
from backend.data.db import transaction
from backend.data.model import User
from backend.data.tally import get_business_understanding_input_from_tally
from backend.data.understanding import (
BusinessUnderstandingInput,
merge_business_understanding_data,
)
from backend.util.exceptions import (
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
_tally_seed_tasks: set[asyncio.Task] = set()
_email_adapter = TypeAdapter(EmailStr)
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
MAX_BULK_INVITE_ROWS = 500
class InvitedUserRecord(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
payload = (
invited_user.tallyUnderstanding
if isinstance(invited_user.tallyUnderstanding, dict)
else None
)
return cls(
id=invited_user.id,
email=invited_user.email,
status=invited_user.status,
auth_user_id=invited_user.authUserId,
name=invited_user.name,
tally_understanding=payload,
tally_status=invited_user.tallyStatus,
tally_computed_at=invited_user.tallyComputedAt,
tally_error=invited_user.tallyError,
created_at=invited_user.createdAt,
updated_at=invited_user.updatedAt,
)
class BulkInvitedUserRowResult(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserRecord] = None
class BulkInvitedUsersResult(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResult]
@dataclass(frozen=True)
class _ParsedInviteRow:
row_number: int
email: str
name: Optional[str]
def normalize_email(email: str) -> str:
return email.strip().lower()
def _normalize_name(name: Optional[str]) -> Optional[str]:
if name is None:
return None
normalized = name.strip()
return normalized or None
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
if preferred_name:
return preferred_name
local_part = email.split("@", 1)[0].strip()
return local_part or "user"
def _sanitize_username_base(email: str) -> str:
local_part = email.split("@", 1)[0].lower()
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
sanitized = sanitized.strip("-")
return sanitized[:40] or "user"
async def _generate_unique_profile_username(email: str, tx) -> str:
base = _sanitize_username_base(email)
for attempt in range(10):
candidate = base if attempt == 0 else f"{base}-{uuid4().hex[:6]}"
existing = await prisma.models.Profile.prisma(tx).find_unique(
where={"username": candidate}
)
if existing is None:
return candidate
raise RuntimeError(f"Unable to generate unique username for {email}")
async def _ensure_default_profile(
user_id: str,
email: str,
preferred_name: Optional[str],
tx,
) -> None:
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
where={"userId": user_id}
)
if existing_profile is not None:
return
username = await _generate_unique_profile_username(email, tx)
await prisma.models.Profile.prisma(tx).create(
data=prisma.types.ProfileCreateInput(
userId=user_id,
name=_default_profile_name(email, preferred_name),
username=username,
description="I'm new here",
links=[],
avatarUrl="",
)
)
async def _ensure_default_onboarding(user_id: str, tx) -> None:
await prisma.models.UserOnboarding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
async def _apply_tally_understanding(
user_id: str,
invited_user: "prisma.models.InvitedUser",
tx,
) -> None:
if not isinstance(invited_user.tallyUnderstanding, dict):
return
input_data = BusinessUnderstandingInput.model_validate(
invited_user.tallyUnderstanding
)
payload = merge_business_understanding_data({}, input_data)
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(payload)},
"update": {"data": SafeJson(payload)},
},
)
async def list_invited_users() -> list[InvitedUserRecord]:
invited_users = await prisma.models.InvitedUser.prisma().find_many(
order={"createdAt": "desc"}
)
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
async def create_invited_user(
email: str, name: Optional[str] = None
) -> InvitedUserRecord:
normalized_email = normalize_email(email)
normalized_name = _normalize_name(name)
existing_user = await prisma.models.User.prisma().find_unique(
where={"email": normalized_email}
)
if existing_user is not None:
raise PreconditionFailed("An active user with this email already exists")
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if existing_invited_user is not None:
raise PreconditionFailed("An invited user with this email already exists")
invited_user = await prisma.models.InvitedUser.prisma().create(
data={
"email": normalized_email,
"name": normalized_name,
"status": prisma.enums.InvitedUserStatus.INVITED,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
}
)
schedule_invited_user_tally_precompute(invited_user.id)
return InvitedUserRecord.from_db(invited_user)
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
raise PreconditionFailed("Claimed invited users cannot be revoked")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return InvitedUserRecord.from_db(invited_user)
revoked_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
)
if revoked_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
return InvitedUserRecord.from_db(revoked_user)
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
refreshed_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": None,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
"tallyComputedAt": None,
"tallyError": None,
},
)
if refreshed_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
schedule_invited_user_tally_precompute(invited_user_id)
return InvitedUserRecord.from_db(refreshed_user)
def _decode_bulk_invite_file(content: bytes) -> str:
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
raise ValueError("Invite file exceeds the maximum size of 1 MB")
try:
return content.decode("utf-8-sig")
except UnicodeDecodeError as exc:
raise ValueError("Invite file must be UTF-8 encoded") from exc
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
indexed_rows: list[tuple[int, list[str]]] = []
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
normalized_row = [cell.strip() for cell in row]
if any(normalized_row):
indexed_rows.append((row_number, normalized_row))
if not indexed_rows:
return []
header = [cell.lower() for cell in indexed_rows[0][1]]
has_header = "email" in header
email_index = header.index("email") if has_header else 0
name_index = header.index("name") if has_header and "name" in header else 1
data_rows = indexed_rows[1:] if has_header else indexed_rows
parsed_rows: list[_ParsedInviteRow] = []
for row_number, row in data_rows:
email = row[email_index].strip() if len(row) > email_index else ""
name = row[name_index].strip() if len(row) > name_index else ""
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=email,
name=name or None,
)
)
return parsed_rows
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
parsed_rows: list[_ParsedInviteRow] = []
for row_number, raw_line in enumerate(text.splitlines(), start=1):
line = raw_line.strip()
if not line or line.startswith("#"):
continue
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=line,
name=None,
)
)
return parsed_rows
def _parse_bulk_invite_file(
filename: Optional[str],
content: bytes,
) -> list[_ParsedInviteRow]:
text = _decode_bulk_invite_file(content)
file_name = filename.lower() if filename else ""
parsed_rows = (
_parse_bulk_invite_csv(text)
if file_name.endswith(".csv")
else _parse_bulk_invite_text(text)
)
if not parsed_rows:
raise ValueError("Invite file did not contain any emails")
if len(parsed_rows) > MAX_BULK_INVITE_ROWS:
raise ValueError(
f"Invite file contains too many rows. Maximum supported rows: {MAX_BULK_INVITE_ROWS}"
)
return parsed_rows
async def bulk_create_invited_users_from_file(
filename: Optional[str],
content: bytes,
) -> BulkInvitedUsersResult:
parsed_rows = _parse_bulk_invite_file(filename, content)
created_count = 0
skipped_count = 0
error_count = 0
results: list[BulkInvitedUserRowResult] = []
seen_emails: set[str] = set()
for row in parsed_rows:
row_name = _normalize_name(row.name)
try:
validated_email = _email_adapter.validate_python(row.email)
except ValidationError:
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=row.email or None,
name=row_name,
status="ERROR",
message="Invalid email address",
)
)
continue
normalized_email = normalize_email(str(validated_email))
if normalized_email in seen_emails:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message="Duplicate email in upload file",
)
)
continue
seen_emails.add(normalized_email)
try:
invited_user = await create_invited_user(normalized_email, row_name)
except PreconditionFailed as exc:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message=str(exc),
)
)
except Exception:
logger.exception(
"Failed to create bulk invite for %s from row %s",
normalized_email,
row.row_number,
)
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="ERROR",
message="Unexpected error creating invite",
)
)
else:
created_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="CREATED",
message="Invite created",
invited_user=invited_user,
)
)
return BulkInvitedUsersResult(
created_count=created_count,
skipped_count=skipped_count,
error_count=error_count,
results=results,
)
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
return
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
"tallyError": None,
},
)
try:
input_data = await get_business_understanding_input_from_tally(
invited_user.email,
require_api_key=True,
)
payload = (
SafeJson(input_data.model_dump(exclude_none=True))
if input_data is not None
else None
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": payload,
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
"tallyComputedAt": datetime.now(timezone.utc),
"tallyError": None,
},
)
except Exception as exc:
logger.exception(
"Failed to compute Tally understanding for invited user %s",
invited_user_id,
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
"tallyError": str(exc),
},
)
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
_tally_seed_tasks.add(task)
task.add_done_callback(_tally_seed_tasks.discard)
async def get_or_activate_user(user_data: dict) -> User:
auth_user_id = user_data.get("sub")
if not auth_user_id:
raise NotAuthorizedError("User ID not found in token")
auth_email = user_data.get("email")
if not auth_email:
raise NotAuthorizedError("Email not found in token")
normalized_email = normalize_email(auth_email)
user_metadata = user_data.get("user_metadata")
metadata_name = (
user_metadata.get("name") if isinstance(user_metadata, dict) else None
)
existing_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if existing_user is not None:
return User.from_db(existing_user)
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
async with transaction() as tx:
current_user = await prisma.models.User.prisma(tx).find_unique(
where={"id": auth_user_id}
)
if current_user is not None:
return User.from_db(current_user)
current_invited_user = await prisma.models.InvitedUser.prisma(tx).find_unique(
where={"email": normalized_email}
)
if current_invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
if current_invited_user.authUserId not in (None, auth_user_id):
raise NotAuthorizedError("Your invitation has already been claimed")
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await prisma.models.InvitedUser.prisma(tx).update(
where={"id": current_invited_user.id},
data={
"status": prisma.enums.InvitedUserStatus.CLAIMED,
"authUserId": auth_user_id,
},
)
await _ensure_default_profile(
auth_user_id,
normalized_email,
preferred_name,
tx,
)
await _ensure_default_onboarding(auth_user_id, tx)
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
from backend.data.user import get_user_by_email, get_user_by_id
get_user_by_id.cache_delete(auth_user_id)
get_user_by_email.cache_delete(normalized_email)
activated_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if activated_user is None:
raise RuntimeError(
f"Activated user {auth_user_id} was not found after creation"
)
return User.from_db(activated_user)

View File

@@ -1,297 +0,0 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import prisma.enums
import pytest
import pytest_mock
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
from .invited_user import (
bulk_create_invited_users_from_file,
create_invited_user,
get_or_activate_user,
retry_invited_user_tally,
)
def _invited_user_db_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="invite-1",
email="invited@example.com",
status=status,
authUserId=None,
name="Invited User",
tallyUnderstanding=tally_understanding,
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
tallyComputedAt=None,
tallyError=None,
createdAt=now,
updatedAt=now,
)
def _user_db_record():
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="auth-user-1",
email="invited@example.com",
emailVerified=True,
name="Invited User",
createdAt=now,
updatedAt=now,
metadata={},
integrations="",
stripeCustomerId=None,
topUpConfig=None,
maxEmailsPerDay=3,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="not-set",
)
@pytest.mark.asyncio
async def test_create_invited_user_rejects_existing_active_user(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock()
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(PreconditionFailed):
await create_invited_user("Invited@example.com")
@pytest.mark.asyncio
async def test_create_invited_user_schedules_tally_seed(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await create_invited_user("Invited@example.com", "Invited User")
assert invited_user.email == "invited@example.com"
invited_user_repo.create.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_retry_invited_user_tally_resets_state_and_schedules(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await retry_invited_user_tally("invite-1")
assert invited_user.id == "invite-1"
invited_user_repo.update.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_get_or_activate_user_requires_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(NotAuthorizedError):
await get_or_activate_user(
{"sub": "auth-user-1", "email": "invited@example.com"}
)
@pytest.mark.asyncio
async def test_get_or_activate_user_creates_user_from_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
tx = object()
invited_user = _invited_user_db_record(
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
)
created_user = _user_db_record()
outside_user_repo = Mock()
outside_user_repo.find_unique = AsyncMock(side_effect=[None, created_user])
inside_user_repo = Mock()
inside_user_repo.find_unique = AsyncMock(return_value=None)
inside_user_repo.create = AsyncMock(return_value=created_user)
outside_invited_repo = Mock()
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo = Mock()
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo.update = AsyncMock(return_value=invited_user)
def user_prisma(client=None):
return inside_user_repo if client is tx else outside_user_repo
def invited_user_prisma(client=None):
return inside_invited_repo if client is tx else outside_invited_repo
@asynccontextmanager
async def fake_transaction():
yield tx
ensure_profile = mocker.patch(
"backend.data.invited_user._ensure_default_profile",
AsyncMock(),
)
ensure_onboarding = mocker.patch(
"backend.data.invited_user._ensure_default_onboarding",
AsyncMock(),
)
apply_tally = mocker.patch(
"backend.data.invited_user._apply_tally_understanding",
AsyncMock(),
)
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
side_effect=invited_user_prisma,
)
mocker.patch("backend.data.user.get_user_by_id.cache_delete", Mock())
mocker.patch("backend.data.user.get_user_by_email.cache_delete", Mock())
user = await get_or_activate_user(
{
"sub": "auth-user-1",
"email": "Invited@example.com",
"user_metadata": {"name": "Invited User"},
}
)
assert user.id == "auth-user-1"
inside_user_repo.create.assert_awaited_once()
inside_invited_repo.update.assert_awaited_once()
ensure_profile.assert_awaited_once()
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
@pytest.mark.asyncio
async def test_bulk_create_invited_users_from_text_file(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_db_record(),
_invited_user_db_record(),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.txt",
b"Invited@example.com\nsecond@example.com\n",
)
assert result.created_count == 2
assert result.skipped_count == 0
assert result.error_count == 0
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
assert create_invited.await_count == 2
@pytest.mark.asyncio
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_db_record(),
PreconditionFailed("An invited user with this email already exists"),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.csv",
(
"email,name\n"
"valid@example.com,Valid User\n"
"not-an-email,Bad Row\n"
"valid@example.com,Duplicate In File\n"
"existing@example.com,Existing User\n"
).encode("utf-8"),
)
assert result.created_count == 1
assert result.skipped_count == 2
assert result.error_count == 1
assert [row.status for row in result.results] == [
"CREATED",
"ERROR",
"SKIPPED",
"SKIPPED",
]
assert create_invited.await_count == 2

View File

@@ -0,0 +1,31 @@
"""LLM Registry - Dynamic model management system."""
from .model import ModelMetadata
from .registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_all_models,
get_default_model_slug,
get_enabled_models,
get_model,
get_schema_options,
refresh_llm_registry,
)
__all__ = [
# Models
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Functions
"refresh_llm_registry",
"get_model",
"get_all_models",
"get_enabled_models",
"get_schema_options",
"get_default_model_slug",
"get_all_model_slugs_for_validation",
]

View File

@@ -0,0 +1,9 @@
"""Type definitions for LLM model metadata.
Re-exports ModelMetadata from blocks.llm to avoid type collision.
In PR #5 (block integration), this will become the canonical location.
"""
from backend.blocks.llm import ModelMetadata
__all__ = ["ModelMetadata"]

View File

@@ -0,0 +1,240 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
unit: str # "RUN" or "TOKENS"
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
# In-memory cache (will be replaced with Redis in PR #6)
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()
async def refresh_llm_registry() -> None:
"""
Refresh the LLM registry from the database.
Fetches all models with their costs, providers, and creators,
then updates the in-memory cache.
"""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.info(f"Fetched {len(records)} LLM models from database")
# Build model instances
new_models: dict[str, RegistryModel] = {}
for record in records:
# Parse costs
costs = tuple(
RegistryModelCost(
unit=str(cost.unit), # Convert enum to string
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=dict(cost.metadata or {}),
)
for cost in (record.Costs or [])
)
# Parse creator
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
# Parse capabilities
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display = (
record.Provider.displayName
if record.Provider
else record.providerId
)
# Extract creator name (fallback to "Unknown" if no creator)
creator_name = (
record.Creator.displayName if record.Creator else "Unknown"
)
# Price tier defaults to 1 if not set
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=(
record.maxOutputTokens
if record.maxOutputTokens is not None
else record.contextWindow
),
display_name=record.displayName,
provider_name=provider_display,
creator_name=creator_name,
price_tier=price_tier,
)
# Create model instance
model = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=capabilities,
extra_metadata=dict(record.metadata or {}),
provider_display_name=provider_display,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
new_models[record.slug] = model
# Atomic swap
global _dynamic_models, _schema_options
_dynamic_models = new_models
_schema_options = _build_schema_options()
logger.info(
f"LLM registry refreshed: {len(_dynamic_models)} models, "
f"{len(_schema_options)} schema options"
)
except Exception as e:
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
raise
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
return [
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
for model in sorted(
_dynamic_models.values(), key=lambda m: m.display_name.lower()
)
if model.is_enabled
]
def get_model(slug: str) -> RegistryModel | None:
"""Get a model by slug from the registry."""
return _dynamic_models.get(slug)
def get_all_models() -> list[RegistryModel]:
"""Get all models from the registry (including disabled)."""
return list(_dynamic_models.values())
def get_enabled_models() -> list[RegistryModel]:
"""Get only enabled models from the registry."""
return [model for model in _dynamic_models.values() if model.is_enabled]
def get_schema_options() -> list[dict[str, str]]:
"""Get schema options for model selection dropdown (enabled models only)."""
return _schema_options
def get_default_model_slug() -> str | None:
"""Get the default model slug (first recommended, or first enabled)."""
# Sort once and use next() to short-circuit on first match
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
# Prefer recommended models
recommended = next(
(m.slug for m in models if m.is_recommended and m.is_enabled), None
)
if recommended:
return recommended
# Fallback to first enabled model
return next((m.slug for m in models if m.is_enabled), None)
def get_all_model_slugs_for_validation() -> list[str]:
"""
Get all model slugs for validation (enables migrate_llm_models to work).
Returns slugs for enabled models only.
"""
return [model.slug for model in _dynamic_models.values() if model.is_enabled]

View File

@@ -14,6 +14,7 @@ from backend.data.understanding import (
get_business_understanding,
upsert_business_understanding,
)
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.request import Requests
from backend.util.settings import Settings
@@ -347,7 +348,7 @@ async def extract_business_understanding(
"""
settings = Settings()
api_key = settings.secrets.open_router_api_key
client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
try:
response = await asyncio.wait_for(
@@ -380,35 +381,6 @@ async def extract_business_understanding(
return BusinessUnderstandingInput(**cleaned)
async def get_business_understanding_input_from_tally(
email: str,
*,
require_api_key: bool = False,
) -> Optional[BusinessUnderstandingInput]:
settings = Settings()
if not settings.secrets.tally_api_key:
if require_api_key:
raise RuntimeError("Tally API key is not configured")
logger.debug("Tally: no API key configured, skipping")
return None
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return None
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return None
return await extract_business_understanding(formatted)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
@@ -423,10 +395,30 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
)
return
understanding_input = await get_business_understanding_input_from_tally(email)
if understanding_input is None:
# Check API key is configured
settings = Settings()
if not settings.secrets.tally_api_key:
logger.debug("Tally: no API key configured, skipping")
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
logger.info(f"Tally: successfully populated understanding for user {user_id}")

View File

@@ -166,56 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
return merged
def merge_business_understanding_data(
existing_data: dict[str, Any],
input_data: BusinessUnderstandingInput,
) -> dict[str, Any]:
merged_data = dict(existing_data)
merged_business: dict[str, Any] = {}
if isinstance(merged_data.get("business"), dict):
merged_business = dict(merged_data["business"])
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
if input_data.user_name is not None:
merged_data["name"] = input_data.user_name
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
merged_business[field] = value
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(merged_business.get(field))
merged_list = _merge_lists(existing_list, value)
merged_business[field] = merged_list
merged_business["version"] = 1
merged_data["business"] = merged_business
return merged_data
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
@@ -295,18 +245,63 @@ async def upsert_business_understanding(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
merged_data = merge_business_understanding_data(existing_data, input_data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(merged_data)},
"update": {"data": SafeJson(merged_data)},
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
},
)

View File

@@ -24,6 +24,7 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.copilot.optimize_blocks import optimize_block_descriptions
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput, GraphInput
from backend.executor import utils as execution_utils
@@ -604,6 +605,19 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
# Block Description Optimization - Every 24 hours
# Generates concise LLM-optimized block descriptions for
# agent generation. Only processes blocks missing descriptions.
self.scheduler.add_job(
optimize_block_descriptions,
id="optimize_block_descriptions",
trigger="interval",
hours=24,
replace_existing=True,
max_instances=1,
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)

View File

@@ -0,0 +1,5 @@
"""LLM registry public API."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,67 @@
"""Pydantic models for LLM registry public API."""
from __future__ import annotations
from typing import Any
import pydantic
class LlmModelCost(pydantic.BaseModel):
"""Cost configuration for an LLM model."""
unit: str # "RUN" or "TOKENS"
credit_cost: int = pydantic.Field(ge=0)
credential_provider: str
credential_id: str | None = None
credential_type: str | None = None
currency: str | None = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCreator(pydantic.BaseModel):
"""Represents the organization that created/trained the model."""
id: str
name: str
display_name: str
description: str | None = None
website_url: str | None = None
logo_url: str | None = None
class LlmModel(pydantic.BaseModel):
"""Public-facing LLM model information."""
slug: str
display_name: str
description: str | None = None
provider_name: str
creator: LlmModelCreator | None = None
context_window: int
max_output_tokens: int | None = None
price_tier: int # 1=cheapest, 2=medium, 3=expensive
is_recommended: bool = False
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
class LlmProvider(pydantic.BaseModel):
"""Provider with its enabled models."""
name: str
display_name: str
models: list[LlmModel] = pydantic.Field(default_factory=list)
class LlmModelsResponse(pydantic.BaseModel):
"""Response for GET /llm/models."""
models: list[LlmModel]
total: int
class LlmProvidersResponse(pydantic.BaseModel):
"""Response for GET /llm/providers."""
providers: list[LlmProvider]

View File

@@ -0,0 +1,141 @@
"""Public read-only API for LLM registry."""
import autogpt_libs.auth
import fastapi
from backend.data.llm_registry import (
RegistryModelCreator,
get_all_models,
get_enabled_models,
)
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
def _map_creator(
creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
"""Convert registry creator to API model."""
if not creator:
return None
return llm_model.LlmModelCreator(
id=creator.id,
name=creator.name,
display_name=creator.display_name,
description=creator.description,
website_url=creator.website_url,
logo_url=creator.logo_url,
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
enabled_only: bool = fastapi.Query(
default=True, description="Only return enabled models"
),
):
"""
List all LLM models available to users.
Returns models from the in-memory registry cache.
Use enabled_only=true to filter to only enabled models (default).
"""
# Get models from in-memory registry
registry_models = get_enabled_models() if enabled_only else get_all_models()
# Map to API response models
models = [
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in registry_models
]
return llm_model.LlmModelsResponse(models=models, total=len(models))
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""
List all LLM providers with their enabled models.
Groups enabled models by provider from the in-memory registry.
"""
# Get all enabled models and group by provider
registry_models = get_enabled_models()
# Group models by provider
provider_map: dict[str, list] = {}
for model in registry_models:
provider_key = model.metadata.provider
if provider_key not in provider_map:
provider_map[provider_key] = []
provider_map[provider_key].append(model)
# Build provider responses
providers = []
for provider_key, models in sorted(provider_map.items()):
# Use the first model's provider display name
display_name = models[0].provider_display_name if models else provider_key
providers.append(
llm_model.LlmProvider(
name=provider_key,
display_name=display_name,
models=[
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in sorted(models, key=lambda m: m.display_name)
],
)
)
return llm_model.LlmProvidersResponse(providers=providers)

View File

@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING
from backend.util.cache import cached, thread_cached
from backend.util.settings import Settings
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
settings = Settings()
if TYPE_CHECKING:
@@ -165,14 +167,20 @@ def get_openai_client() -> "AsyncOpenAI | None":
"""
Get a process-cached async OpenAI client for embeddings.
Returns None if API key is not configured.
Prefers openai_internal_api_key (direct OpenAI). Falls back to
open_router_api_key via OpenRouter's OpenAI-compatible endpoint.
Returns None if neither key is configured.
"""
from openai import AsyncOpenAI
api_key = settings.secrets.openai_internal_api_key
if not api_key:
return None
return AsyncOpenAI(api_key=api_key)
if settings.secrets.openai_internal_api_key:
return AsyncOpenAI(api_key=settings.secrets.openai_internal_api_key)
if settings.secrets.open_router_api_key:
return AsyncOpenAI(
api_key=settings.secrets.open_router_api_key,
base_url=OPENROUTER_BASE_URL,
)
return None
# ============ Notification Queue Helpers ============ #

View File

@@ -144,76 +144,106 @@ async def _resolve_host(hostname: str) -> list[str]:
return ip_addresses
async def validate_url(
url: str, trusted_origins: list[str]
async def validate_url_host(
url: str, trusted_hostnames: Optional[list[str]] = None
) -> tuple[URL, bool, list[str]]:
"""
Validates the URL to prevent SSRF attacks by ensuring it does not point
to a private, link-local, or otherwise blocked IP address — unless
Validates a (URL's) host string to prevent SSRF attacks by ensuring it does not
point to a private, link-local, or otherwise blocked IP address — unless
the hostname is explicitly trusted.
Hosts in `trusted_hostnames` are permitted without checks.
All other hosts are resolved and checked against `BLOCKED_IP_NETWORKS`.
Params:
url: A hostname, netloc, or URL to validate.
If no scheme is included, `http://` is assumed.
trusted_hostnames: A list of hostnames that don't require validation.
Raises:
ValueError:
- if the URL has a disallowed URL scheme
- if the URL/host string can't be parsed
- if the hostname contains invalid or unsupported (non-ASCII) characters
- if the host resolves to a blocked IP
Returns:
str: The validated, canonicalized, parsed URL
is_trusted: Boolean indicating if the hostname is in trusted_origins
ip_addresses: List of IP addresses for the host; empty if the host is trusted
1. The validated, canonicalized, parsed host/URL,
with hostname ASCII-safe encoding
2. Whether the host is trusted (based on the passed `trusted_hostnames`).
3. List of resolved IP addresses for the host; empty if the host is trusted.
"""
parsed = parse_url(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise ValueError(
f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
f"URL scheme '{parsed.scheme}' is not allowed; allowed schemes: "
f"{', '.join(ALLOWED_SCHEMES)}"
)
# Validate and IDNA encode hostname
if not parsed.hostname:
raise ValueError("Invalid URL: No hostname found.")
raise ValueError(f"Invalid host/URL; no host in parse result: {url}")
# IDNA encode to prevent Unicode domain attacks
try:
ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
except idna.IDNAError:
raise ValueError("Invalid hostname with unsupported characters.")
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
# Check hostname characters
if not HOSTNAME_REGEX.match(ascii_hostname):
raise ValueError("Hostname contains invalid characters.")
raise ValueError(f"Hostname '{parsed.hostname}' has unsupported characters")
# Check if hostname is trusted
is_trusted = ascii_hostname in trusted_origins
# If not trusted, validate IP addresses
ip_addresses: list[str] = []
if not is_trusted:
# Resolve all IP addresses for the hostname
ip_addresses = await _resolve_host(ascii_hostname)
# Block any IP address that belongs to a blocked range
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {ascii_hostname} is not allowed."
)
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
netloc = ascii_hostname
if parsed.port:
netloc = f"{ascii_hostname}:{parsed.port}"
return (
URL(
parsed.scheme,
netloc,
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
),
is_trusted,
ip_addresses,
# Re-create parsed URL object with IDNA-encoded hostname
parsed = URL(
parsed.scheme,
(ascii_hostname if parsed.port is None else f"{ascii_hostname}:{parsed.port}"),
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
parsed.fragment,
)
is_trusted = trusted_hostnames and any(
matches_allowed_host(parsed, allowed)
for allowed in (
# Normalize + parse allowlist entries the same way for consistent matching
parse_url(w)
for w in trusted_hostnames
)
)
if is_trusted:
return parsed, True, []
# If not allowlisted, go ahead with host resolution and IP target check
return parsed, False, await _resolve_and_check_blocked(ascii_hostname)
def matches_allowed_host(url: URL, allowed: URL) -> bool:
if url.hostname != allowed.hostname:
return False
# Allow any port if not explicitly specified in the allowlist
if allowed.port is None:
return True
return url.port == allowed.port
async def _resolve_and_check_blocked(hostname: str) -> list[str]:
"""
Resolves hostname to IPs and raises ValueError if any resolve to
a blocked network. Returns the list of resolved IP addresses.
"""
ip_addresses = await _resolve_host(hostname)
for ip_str in ip_addresses:
if _is_ip_blocked(ip_str):
raise ValueError(
f"Access to blocked or private IP address {ip_str} "
f"for hostname {hostname} is not allowed."
)
return ip_addresses
def parse_url(url: str) -> URL:
"""Canonicalizes and parses a URL string."""
@@ -352,7 +382,7 @@ class Requests:
):
self.trusted_origins = []
for url in trusted_origins or []:
hostname = urlparse(url).hostname
hostname = parse_url(url).netloc # {host}[:{port}]
if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname of {url}")
self.trusted_origins.append(hostname)
@@ -450,7 +480,7 @@ class Requests:
data = form
# Validate URL and get trust status
parsed_url, is_trusted, ip_addresses = await validate_url(
parsed_url, is_trusted, ip_addresses = await validate_url_host(
url, self.trusted_origins
)
@@ -503,7 +533,6 @@ class Requests:
json=json,
**kwargs,
) as response:
if self.raise_for_status:
try:
response.raise_for_status()

View File

@@ -1,7 +1,7 @@
import pytest
from aiohttp import web
from backend.util.request import pin_url, validate_url
from backend.util.request import pin_url, validate_url_host
@pytest.mark.parametrize(
@@ -60,9 +60,9 @@ async def test_validate_url_no_dns_rebinding(
):
if should_raise:
with pytest.raises(ValueError):
await validate_url(raw_url, trusted_origins)
await validate_url_host(raw_url, trusted_origins)
else:
validated_url, _, _ = await validate_url(raw_url, trusted_origins)
validated_url, _, _ = await validate_url_host(raw_url, trusted_origins)
assert validated_url.geturl() == expected_value
@@ -101,10 +101,10 @@ async def test_dns_rebinding_fix(
if expect_error:
# If any IP is blocked, we expect a ValueError
with pytest.raises(ValueError):
url, _, ip_addresses = await validate_url(hostname, [])
url, _, ip_addresses = await validate_url_host(hostname)
pin_url(url, ip_addresses)
else:
url, _, ip_addresses = await validate_url(hostname, [])
url, _, ip_addresses = await validate_url_host(hostname)
pinned_url = pin_url(url, ip_addresses).geturl()
# The pinned_url should contain the first valid IP
assert pinned_url.startswith("http://") or pinned_url.startswith("https://")

View File

@@ -89,6 +89,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=500,
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
)
ollama_host: str = Field(
default="localhost:11434",
description="Default Ollama host; exempted from SSRF checks.",
)
pyro_host: str = Field(
default="localhost",
description="The default hostname of the Pyro server.",
@@ -363,23 +367,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Whether to mark failed scans as clean or not",
)
agentgenerator_host: str = Field(
default="",
description="The host for the Agent Generator service (empty to use built-in)",
)
agentgenerator_port: int = Field(
default=8000,
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
default=1800,
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
)
agentgenerator_use_dummy: bool = Field(
default=False,
description="Use dummy agent generator responses for testing (bypasses external service)",
)
enable_example_blocks: bool = Field(
default=False,
description="Whether to enable example blocks in production",

View File

@@ -0,0 +1,3 @@
-- AlterTable
ALTER TABLE "AgentBlock" ADD COLUMN "description" TEXT,
ADD COLUMN "optimizedDescription" TEXT;

Some files were not shown because too many files have changed in this diff Show More